In [ ]:
import os
import time
import shutil
import numpy as np
import rasterio
from rasterio.windows import from_bounds
import cv2
import ee
import geemap
from google.colab import drive
from sklearn.model_selection import train_test_split
!pip install geedim
!pip install geemap
# 1. SETUP
drive.mount('/content/drive', force_remount=True)
MY_PROJECT_ID = '[REDACTED_FOR_SECURITY]'
# Authentication Block
try:
ee.Initialize(project=MY_PROJECT_ID)
print(f" Earth Engine Initialized with YOUR project: {MY_PROJECT_ID}")
except:
ee.Authenticate()
ee.Initialize(project=MY_PROJECT_ID)
print(f" Earth Engine Authenticated and Initialized with YOUR project: {MY_PROJECT_ID}")
ASSET_ID = '[REDACTED_FOR_SECURITY]'
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
PATCH_SIZE = 224
S2_SCALE = 5000.0
# Prithvi 3-Step Temporal Configuration
TIME_WINDOWS = [
('2024-10-15', '2024-11-15'), # T1: Sowing
('2025-01-01', '2025-01-31'), # T2: Peak Growth 1
('2025-02-15', '2025-03-15') # T3: Peak Growth 2
]
def generate_prithvi_npy():
print(f" Starting Prithvi Data Generation.")
print(f" Billing Project: {MY_PROJECT_ID}")
print(f" Reading Asset: {ASSET_ID}")
# A. Mask Ingestion
mask_img = ee.Image(ASSET_ID)
roi_geom = mask_img.geometry()
mask_file = 'local_mask_prithvi.tif'
if not os.path.exists(mask_file):
print(" Downloading Mask...")
# geemap will now use MY_PROJECT_ID to authorize the download
geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)
# B. Define Spatial Subset
with rasterio.open(mask_file) as src:
b = src.bounds
cx, cy = (b.left + b.right)/2, (b.bottom + b.top)/2
offset = 0.04
window = from_bounds(cx-offset, cy-offset, cx+offset, cy+offset, src.transform)
mask = src.read(1, window=window)
mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)
target_h, target_w = mask.shape
small_roi = ee.Geometry.BBox(cx-offset, cy-offset, cx+offset, cy+offset)
print(f" ROI Subset Defined: {target_h}x{target_w}")
# C. Multi-Temporal Stack Generation
stack = []
for i, (start, end) in enumerate(TIME_WINDOWS):
fname = f'prithvi_time_{i}.tif'
# Robust Download Logic
attempts = 0
while not os.path.exists(fname) and attempts < 3:
try:
print(f" Downloading Time Step {i+1}/{len(TIME_WINDOWS)}: {start} to {end}...")
# Optical Only (6 Bands)
img = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \
.filterBounds(small_roi) \
.filterDate(start, end) \
.median() \
.select(['B2','B3','B4','B8','B11','B12'])
geemap.download_ee_image(img, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
except Exception as e:
print(f" Error downloading {fname}: {e}")
attempts += 1
time.sleep(2)
# Fallback Logic
if not os.path.exists(fname):
print(f" Failed to download {fname}. Attempting fallback...")
if i > 0:
shutil.copy(f'prithvi_time_{i-1}.tif', fname)
else:
# Create dummy zero array if first step fails
with rasterio.open(mask_file) as src:
profile = src.profile
profile.update(count=6, dtype=rasterio.float32)
with rasterio.open(fname, 'w', **profile) as dst:
dst.write(np.zeros((6, target_h, target_w), dtype=np.float32))
# Processing & Normalization
with rasterio.open(fname) as src:
arr = src.read() # (6, H, W)
arr = np.transpose(arr, (1, 2, 0)) # (H, W, 6)
if arr.shape[:2] != (target_h, target_w):
arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# Normalization
arr = np.clip(arr / S2_SCALE, 0, 1).astype(np.float32)
stack.append(arr)
# D. Stack Construction
full_cube = np.stack(stack, axis=2)
# E. Tiling
x_out, y_out = [], []
stride = PATCH_SIZE
print(" Tiling patches...")
for y in range(0, target_h, stride):
for x in range(0, target_w, stride):
img_p = full_cube[y:y+stride, x:x+stride]
mask_p = mask[y:y+stride, x:x+stride]
if img_p.shape[0] != PATCH_SIZE or img_p.shape[1] != PATCH_SIZE: continue
if np.mean(mask_p) < 0.01: continue
if np.isnan(img_p).any(): continue
x_out.append(img_p)
y_out.append(mask_p)
if len(x_out) == 0: raise ValueError("No valid patches generated.")
# F. Transpose
X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)
y = np.array(y_out, dtype=np.float32)[:, None, :, :]
print(f" Dataset Generated. Shape: {X.shape}")
# G. Save to Drive
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
print(" Saving .npy files to Google Drive...")
np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)
print(" DONE! Proceed to Training Cell.")
# Execute
generate_prithvi_npy()
Collecting geedim Downloading geedim-2.0.0-py3-none-any.whl.metadata (6.0 kB) Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.12/dist-packages (from geedim) (2.0.2) Requirement already satisfied: rasterio>=1.3.8 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.0) Requirement already satisfied: click>=8 in /usr/local/lib/python3.12/dist-packages (from geedim) (8.3.1) Requirement already satisfied: tqdm>=4.6 in /usr/local/lib/python3.12/dist-packages (from geedim) (4.67.1) Requirement already satisfied: earthengine-api>=0.1.379 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.24) Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.12/dist-packages (from geedim) (0.9.0) Requirement already satisfied: fsspec>=2025.2 in /usr/local/lib/python3.12/dist-packages (from geedim) (2025.3.0) Requirement already satisfied: aiohttp>=3.11 in /usr/local/lib/python3.12/dist-packages (from geedim) (3.13.3) Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (2.6.1) Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.4.0) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (25.4.0) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.8.0) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (6.7.0) Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (0.4.1) Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.22.0) Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (3.7.0) Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.187.0) Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.43.0) Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.3.0) Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.31.0) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.32.4) Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2.4.0) Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2026.1.4) Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (0.7.2) Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (3.3.1) Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp>=3.11->geedim) (4.15.0) Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (2.29.0) Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (4.2.0) Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (6.2.4) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (4.9.1) Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.12/dist-packages (from yarl<2.0,>=1.17.0->aiohttp>=3.11->geedim) (3.11) Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.5.0) Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.8.0) Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (1.8.0) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (3.4.4) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (2.5.0) Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.72.0) Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (5.29.5) Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.27.0) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.6.1) Downloading geedim-2.0.0-py3-none-any.whl (73 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 73.1/73.1 kB 5.3 MB/s eta 0:00:00 Installing collected packages: geedim Successfully installed geedim-2.0.0 Requirement already satisfied: geemap in /usr/local/lib/python3.12/dist-packages (0.35.3) Requirement already satisfied: bqplot in /usr/local/lib/python3.12/dist-packages (from geemap) (0.12.45) Requirement already satisfied: colour in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.5) Requirement already satisfied: earthengine-api>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (1.5.24) Requirement already satisfied: eerepr>=0.1.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.2) Requirement already satisfied: folium>=0.17.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0) Requirement already satisfied: geocoder in /usr/local/lib/python3.12/dist-packages (from geemap) (1.38.1) Requirement already satisfied: ipyevents in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.4) Requirement already satisfied: ipyfilechooser>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.6.0) Requirement already satisfied: ipyleaflet>=0.19.2 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0) Requirement already satisfied: ipytree in /usr/local/lib/python3.12/dist-packages (from geemap) (0.2.2) Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from geemap) (3.10.0) Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.2) Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from geemap) (2.2.2) Requirement already satisfied: plotly in /usr/local/lib/python3.12/dist-packages (from geemap) (5.24.1) Requirement already satisfied: pyperclip in /usr/local/lib/python3.12/dist-packages (from geemap) (1.11.0) Requirement already satisfied: pyshp>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from geemap) (3.0.3) Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from geemap) (7.3.2) Requirement already satisfied: scooby in /usr/local/lib/python3.12/dist-packages (from geemap) (0.11.0) Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (3.7.0) Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.187.0) Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.43.0) Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.3.0) Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.31.0) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.32.4) Requirement already satisfied: branca>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (0.8.2) Requirement already satisfied: jinja2>=2.9 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (3.1.6) Requirement already satisfied: xyzservices in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (2025.11.0) Requirement already satisfied: ipywidgets in /usr/local/lib/python3.12/dist-packages (from ipyfilechooser>=0.6.0->geemap) (7.7.1) Requirement already satisfied: jupyter-leaflet<0.21,>=0.20 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.20.0) Requirement already satisfied: traittypes<3,>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.2.3) Requirement already satisfied: traitlets>=4.3.0 in /usr/local/lib/python3.12/dist-packages (from bqplot->geemap) (5.7.1) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.2) Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.3) Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (8.3.1) Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.0.0) Requirement already satisfied: ratelim in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (0.1.6) Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.17.0) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.3.3) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (4.61.1) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.4.9) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (25.0) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (11.3.0) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (3.3.1) Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.12/dist-packages (from plotly->geemap) (9.1.2) Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (2.29.0) Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (4.2.0) Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (6.2.4) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (4.9.1) Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.17.1) Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0) Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.6.10) Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.34.0) Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.16) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2>=2.9->folium>=0.17.0->geemap) (3.0.3) Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.5.0) Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.8.0) Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (1.8.0) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.4.4) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.11) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2.5.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2026.1.4) Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from ratelim->geocoder->geemap) (4.4.2) Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.72.0) Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (5.29.5) Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.27.0) Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.8.15) Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.4.9) Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.1) Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.6.0) Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.5) Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (26.2.1) Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.1) Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (75.2.0) Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB) Requirement already satisfied: pickleshare in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.5) Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.52) Requirement already satisfied: pygments in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.19.2) Requirement already satisfied: backcall in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0) Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.9.0) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.6.1) Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.12/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.7) Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.12/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.8.5) Requirement already satisfied: entrypoints in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.4) Requirement already satisfied: jupyter-core>=4.9.2 in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.1) Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0) Requirement already satisfied: nbformat in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.10.4) Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.16.6) Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0) Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.18.1) Requirement already satisfied: prometheus-client in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.23.1) Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.3) Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.12/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.0) Requirement already satisfied: wcwidth in /usr/local/lib/python3.12/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.14) Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.12/dist-packages (from jupyter-core>=4.9.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.5.1) Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.12/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.4) Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.13.5) Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.3.0) Requirement already satisfied: defusedxml in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.1) Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.3.0) Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.2.0) Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.10.4) Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1) Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.21.2) Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.26.0) Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.12/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0) Requirement already satisfied: webencodings in /usr/local/lib/python3.12/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.1) Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0) Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.4.0) Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2025.9.1) Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.37.0) Requirement already satisfied: rpds-py>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.30.0) Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.12/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.14.0) Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0) Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.8.1) Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.15.0) Requirement already satisfied: pycparser in /usr/local/lib/python3.12/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.23) Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.12.1) Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.12.0) Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.3) Requirement already satisfied: overrides>=5.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.7.0) Requirement already satisfied: websocket-client>=1.7 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.9.0) Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.0.0) Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.0.3) Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.4) Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.1) Requirement already satisfied: fqdn in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1) Requirement already satisfied: isoduration in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (20.11.0) Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.0) Requirement already satisfied: rfc3987-syntax>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.1.0) Requirement already satisfied: uri-template in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.0) Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.10.0) Requirement already satisfied: lark>=1.2.2 in /usr/local/lib/python3.12/dist-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.1) Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0) Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 86.8 MB/s eta 0:00:00 Installing collected packages: jedi Successfully installed jedi-0.19.2 Mounted at /content/drive Earth Engine Authenticated and Initialized with YOUR project: local-dialect-484618-b9 Starting Prithvi Data Generation. Billing Project: local-dialect-484618-b9 Reading Asset: projects/satmae-2026/assets/Punjab_Mask_2024_NEW Downloading Mask...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release. Please use the 'ee.Image.gd' accessor instead. img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW: 0%| |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'. return STACClient().get(self.id)
ROI Subset Defined: 891x891 Downloading Time Step 1/3: 2024-10-15 to 2024-11-15...
0%| |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'. return STACClient().get(self.id)
Downloading Time Step 2/3: 2025-01-01 to 2025-01-31...
0%| |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading Time Step 3/3: 2025-02-15 to 2025-03-15...
0%| |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Tiling patches... Dataset Generated. Shape: (9, 6, 3, 224, 224) Saving .npy files to Google Drive... DONE! Proceed to Training Cell.
In [ ]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
CONFIG = {
"EPOCHS": 5000,
"PATIENCE": 200,
"BATCH_SIZE": 8,
"LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 0.05,
"WARMUP_EPOCHS": 20, # Linear Warmup
"SWA_START_EPOCH": None, # Dynamic
"RESUME": True,
"DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
"IMG_SIZE": 224,
"NUM_FRAMES": 3,
"IN_CHANS": 6,
"EMBED_DIM": 768,
"DEPTH": 12,
"NUM_HEADS": 12,
"SAVE_DIR": SAVE_DIR
}
class PunjabWheatDataset(Dataset):
def __init__(self, x_path, y_path):
self.data = np.load(x_path, mmap_mode='r')
self.masks = np.load(y_path, mmap_mode='r')
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img = self.data[idx]
mask = self.masks[idx]
return torch.from_numpy(img.copy()).float(), torch.from_numpy(mask.copy()).float()
In [ ]:
!pip install segmentation-models-pytorch
Collecting segmentation-models-pytorch Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB) Requirement already satisfied: huggingface-hub>=0.24 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.36.0) Requirement already satisfied: numpy>=1.19.3 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.0.2) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (11.3.0) Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.7.0) Requirement already satisfied: timm>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (1.0.24) Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.9.0+cu126) Requirement already satisfied: torchvision>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.24.0+cu126) Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (4.67.1) Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (3.20.2) Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2025.3.0) Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (25.0) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (6.0.3) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2.32.4) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (4.15.0) Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (1.2.0) Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (75.2.0) Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.14.0) Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.6.1) Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.1.6) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.80) Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (9.10.2.21) Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.4.1) Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.3.0.4) Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (10.3.7.77) Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.7.1.2) Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.5.4.2) Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (0.7.1) Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (2.27.5) Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.3.20) Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77) Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.85) Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.11.1.6) Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.5.0) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8->segmentation-models-pytorch) (1.3.0) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8->segmentation-models-pytorch) (3.0.3) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.4.4) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.11) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2.5.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2026.1.4) Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.8/154.8 kB 12.7 MB/s eta 0:00:00 Installing collected packages: segmentation-models-pytorch Successfully installed segmentation-models-pytorch-0.5.0
In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import distance_transform_edt as distance
import segmentation_models_pytorch as smp
# --- 1. HEAVY TRAINING AUGMENTATION ---
def apply_training_augmentation(x, y):
# Random Horizontal Flip
if np.random.rand() > 0.5:
x = torch.flip(x, [4])
y = torch.flip(y, [3])
# Random Vertical Flip
if np.random.rand() > 0.5:
x = torch.flip(x, [3])
y = torch.flip(y, [2])
# Random Rotation (0, 90, 180, 270)
k = np.random.randint(0, 4)
x = torch.rot90(x, k, [3, 4])
y = torch.rot90(y, k, [2, 3])
# Random Intensity Scale (Brightness/Contrast jitter)
if np.random.rand() > 0.5:
noise = (torch.rand(x.shape[0], 1, 1, 1, 1, device=x.device) * 0.2) + 0.9
x = x * noise
# --- FIX: Force contiguous memory layout ---
return x.contiguous(), y.contiguous()
# --- 2. VALIDATION TTA ---
class TestTimeAugmentation(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x):
pred_orig = self.model(x)
x_hflip = torch.flip(x, [4])
pred_hflip = torch.flip(self.model(x_hflip), [3])
x_vflip = torch.flip(x, [3])
pred_vflip = torch.flip(self.model(x_vflip), [2])
x_rot = torch.rot90(x, 1, [3, 4])
pred_rot = torch.rot90(self.model(x_rot), -1, [2, 3])
return torch.stack([pred_orig, pred_hflip, pred_vflip, pred_rot]).mean(dim=0)
def validate_with_tta(model, loader, criterion, device):
model.eval()
total_loss = 0
tta_model = TestTimeAugmentation(model)
with torch.no_grad():
for x, y in loader:
x, y = x.to(device), y.to(device)
preds = tta_model(x)
loss = criterion(preds, y)
total_loss += loss.item()
return total_loss / len(loader)
# --- 3. LOSS FUNCTIONS (Dice 0.7 + Hausdorff 0.3) ---
class HausdorffDTLoss(nn.Module):
def __init__(self, alpha=2.0):
super().__init__()
self.alpha = alpha
def forward(self, pred, gt):
# Hausdorff Distance Transform (CPU based)
with torch.no_grad():
gt_np = gt.cpu().numpy()
dist_map = np.zeros_like(gt_np)
for i in range(len(gt_np)):
mask = (gt_np[i, 0] > 0.5).astype(np.uint8)
if mask.sum() == 0:
# If empty mask, distance is max everywhere
dist_map[i, 0] = np.ones_like(mask) * 100.0
continue
d_in = distance(mask)
d_out = distance(1 - mask)
dist_map[i, 0] = (d_out - d_in)
dist_map = torch.tensor(dist_map, device=pred.device, dtype=torch.float32)
probs = torch.sigmoid(pred)
# Weighted MSE based on distance map
return torch.mean((probs - gt) ** 2 * (1 + self.alpha * torch.abs(dist_map)))
class CompoundLoss(nn.Module):
def __init__(self):
super().__init__()
self.dice = smp.losses.DiceLoss(mode='binary', from_logits=True)
self.hausdorff = HausdorffDTLoss(alpha=2.0)
def forward(self, p, t):
# 0.7 Dice + 0.3 Hausdorff
return 0.7 * self.dice(p, t) + 0.3 * self.hausdorff(p, t)
In [ ]:
import numpy as np
import torch
import torch.nn as nn
# CORRECT 3D POSITIONAL EMBEDDING
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega
pos = pos.reshape(-1)
out = np.einsum('m,d->md', pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size):
# 1. Height Embedding
grid_h = np.arange(grid_size, dtype=np.float32)
pos_embed_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_h)
# 2. Width Embedding
grid_w = np.arange(grid_size, dtype=np.float32)
pos_embed_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_w)
# 3. Time Embedding
grid_t = np.arange(t_size, dtype=np.float32)
pos_embed_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
# 4. Broadcast and Sum: Final = T + H + W
pos_t = pos_embed_t[:, np.newaxis, np.newaxis, :]
pos_h = pos_embed_h[np.newaxis, :, np.newaxis, :]
pos_w = pos_embed_w[np.newaxis, np.newaxis, :, :]
pos_embed = pos_t + pos_h + pos_w
# Flatten to (T*H*W, D) to match Transformer input
pos_embed = pos_embed.reshape(-1, embed_dim)
return torch.from_numpy(pos_embed).float().unsqueeze(0)
class PatchEmbed3D(nn.Module):
def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
super().__init__()
self.proj = nn.Conv3d(
in_chans, embed_dim,
kernel_size=(1, patch_size, patch_size),
stride=(1, patch_size, patch_size)
)
def forward(self, x):
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class PrithviScratch(nn.Module):
def __init__(self):
super().__init__()
self.patch_embed = PatchEmbed3D(
frames=CONFIG["NUM_FRAMES"],
in_chans=CONFIG["IN_CHANS"],
embed_dim=CONFIG["EMBED_DIM"]
)
self.register_buffer(
"pos_embed",
get_3d_sincos_pos_embed(
CONFIG["EMBED_DIM"],
CONFIG["IMG_SIZE"] // 16,
CONFIG["NUM_FRAMES"]
)
)
encoder_layer = nn.TransformerEncoderLayer(
d_model=CONFIG["EMBED_DIM"], nhead=CONFIG["NUM_HEADS"],
dim_feedforward=CONFIG["EMBED_DIM"]*4, dropout=0.1,
activation='gelu', batch_first=True, norm_first=True
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=CONFIG["DEPTH"])
in_dim = CONFIG["EMBED_DIM"] * CONFIG["NUM_FRAMES"]
self.decoder = nn.Sequential(
nn.Conv2d(in_dim, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(64, 1, kernel_size=1)
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None: nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.encoder(x)
H_p = CONFIG["IMG_SIZE"] // 16
x = x.transpose(1, 2).view(B, CONFIG["EMBED_DIM"], CONFIG["NUM_FRAMES"], H_p, H_p)
x = x.reshape(B, -1, H_p, H_p)
return self.decoder(x)
In [ ]:
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
train_x = os.path.join(SAVE_DIR, 'train_x.npy')
train_y = os.path.join(SAVE_DIR, 'train_y.npy')
val_x = os.path.join(SAVE_DIR, 'val_x.npy')
val_y = os.path.join(SAVE_DIR, 'val_y.npy')
if not os.path.exists(train_x):
raise FileNotFoundError(f"Data not found in {SAVE_DIR}. Run Cell 1 first.")
print("Loading Data...")
train_loader = DataLoader(PunjabWheatDataset(train_x, train_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=2)
val_loader = DataLoader(PunjabWheatDataset(val_x, val_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=2)
model = PrithviScratch().to(CONFIG["DEVICE"])
criterion = CompoundLoss().to(CONFIG["DEVICE"])
optimizer = optim.AdamW(model.parameters(), lr=CONFIG["LEARNING_RATE"], weight_decay=CONFIG["WEIGHT_DECAY"])
# --- SCHEDULER: WARMUP + COSINE DECAY ---
scheduler_warmup = LinearLR(optimizer, start_factor=0.01, total_iters=CONFIG["WARMUP_EPOCHS"])
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=CONFIG["EPOCHS"] - CONFIG["WARMUP_EPOCHS"], eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[CONFIG["WARMUP_EPOCHS"]])
# --- DYNAMIC SWA ---
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
swa_active = False
checkpoint_path = os.path.join(CONFIG["SAVE_DIR"], "checkpoint.pth")
start_epoch = 0
best_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': []}
if CONFIG["RESUME"] and os.path.exists(checkpoint_path):
print("Found checkpoint. Attempting to load...")
try:
ckpt = torch.load(checkpoint_path)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1
best_loss = ckpt.get('best_loss', float('inf'))
history = ckpt.get('history', {'train_loss': [], 'val_loss': []})
if 'swa_model' in ckpt:
swa_model.load_state_dict(ckpt['swa_model'])
if 'swa_active' in ckpt:
swa_active = ckpt['swa_active']
print(f"Successfully resumed from Epoch {start_epoch} (SWA: {swa_active})")
except Exception as e:
print(f"Checkpoint corrupted ({e}). Starting from Epoch 0.")
start_epoch = 0
print("Starting Training...")
for epoch in range(start_epoch, CONFIG["EPOCHS"]):
model.train()
train_loss = 0
for x, y in train_loader:
x, y = x.to(CONFIG["DEVICE"]), y.to(CONFIG["DEVICE"])
# Heavy Augmentation
x, y = apply_training_augmentation(x, y)
optimizer.zero_grad()
preds = model(x)
# FIX: Ensure tensors are contiguous before Loss
loss = criterion(preds, y.contiguous())
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
train_loss += loss.item()
avg_train_loss = train_loss / len(train_loader)
# Validation with TTA
avg_val_loss = validate_with_tta(model, val_loader, criterion, CONFIG["DEVICE"])
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
# --- DYNAMIC SWA LOGIC ---
SWA_TRIGGER = CONFIG["PATIENCE"] // 2
if patience_counter >= SWA_TRIGGER and not swa_active:
print(f"Stagnation Detected (Patience {patience_counter}). Activating SWA.")
swa_active = True
swa_model.update_parameters(model)
if swa_active:
swa_model.update_parameters(model)
swa_scheduler.step()
else:
scheduler.step()
checkpoint_dict = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_loss': best_loss,
'swa_model': swa_model.state_dict(),
'history': history,
'swa_active': swa_active
}
torch.save(checkpoint_dict, checkpoint_path)
if epoch % 10 == 0:
backup_path = os.path.join(CONFIG["SAVE_DIR"], f"checkpoint_epoch_{epoch}.pth")
torch.save(checkpoint_dict, backup_path)
if avg_val_loss < best_loss:
best_loss = avg_val_loss
torch.save(model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "best_model.pth"))
patience_counter = 0
print(f"New Best Model! Loss: {best_loss:.4f}")
else:
patience_counter += 1
if patience_counter >= CONFIG["PATIENCE"]:
print(f"Early Stopping at Epoch {epoch}")
if swa_active:
torch.save(swa_model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "swa_model_final.pth"))
break
if epoch % 5 == 0:
mode = "SWA" if swa_active else "STD"
print(f"Ep {epoch} [{mode}] | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
Loading Data... Starting Training... New Best Model! Loss: 1.4913 Ep 0 [STD] | Train: 1.4160 | Val: 1.4913 New Best Model! Loss: 1.2440 New Best Model! Loss: 0.8907 Ep 5 [STD] | Train: 1.0207 | Val: 0.8907 New Best Model! Loss: 0.8533 Ep 10 [STD] | Train: 0.7532 | Val: 0.8762 Ep 15 [STD] | Train: 0.7301 | Val: 0.8679 New Best Model! Loss: 0.8505 New Best Model! Loss: 0.7767 New Best Model! Loss: 0.6073 Ep 20 [STD] | Train: 0.7010 | Val: 0.6073 Ep 25 [STD] | Train: 0.4928 | Val: 0.8917 Ep 30 [STD] | Train: 0.4096 | Val: 0.8945 New Best Model! Loss: 0.5724 New Best Model! Loss: 0.5573 Ep 35 [STD] | Train: 0.3281 | Val: 0.5573 New Best Model! Loss: 0.5204 New Best Model! Loss: 0.4516 New Best Model! Loss: 0.3923 New Best Model! Loss: 0.3576 New Best Model! Loss: 0.3575 Ep 40 [STD] | Train: 0.3195 | Val: 0.3575 New Best Model! Loss: 0.3457 New Best Model! Loss: 0.3014 New Best Model! Loss: 0.2913 Ep 45 [STD] | Train: 0.2988 | Val: 0.2972 Ep 50 [STD] | Train: 0.2828 | Val: 0.3110 New Best Model! Loss: 0.2756 New Best Model! Loss: 0.2655 New Best Model! Loss: 0.2621 New Best Model! Loss: 0.2613 Ep 55 [STD] | Train: 0.2701 | Val: 0.2613 Ep 60 [STD] | Train: 0.2748 | Val: 0.2653 New Best Model! Loss: 0.2567 New Best Model! Loss: 0.2563 Ep 65 [STD] | Train: 0.2697 | Val: 0.2563 New Best Model! Loss: 0.2558 Ep 70 [STD] | Train: 0.2699 | Val: 0.2564 New Best Model! Loss: 0.2539 Ep 75 [STD] | Train: 0.2644 | Val: 0.2539 Ep 80 [STD] | Train: 0.2602 | Val: 0.2589 New Best Model! Loss: 0.2536 New Best Model! Loss: 0.2522 New Best Model! Loss: 0.2517 New Best Model! Loss: 0.2502 Ep 85 [STD] | Train: 0.2627 | Val: 0.2502 New Best Model! Loss: 0.2487 Ep 90 [STD] | Train: 0.2536 | Val: 0.2597 Ep 95 [STD] | Train: 0.2594 | Val: 0.2517 New Best Model! Loss: 0.2474 Ep 100 [STD] | Train: 0.2469 | Val: 0.2514 Ep 105 [STD] | Train: 0.2486 | Val: 0.2518 New Best Model! Loss: 0.2471 New Best Model! Loss: 0.2461 New Best Model! Loss: 0.2449 New Best Model! Loss: 0.2426 Ep 110 [STD] | Train: 0.2565 | Val: 0.2426 New Best Model! Loss: 0.2423 Ep 115 [STD] | Train: 0.2402 | Val: 0.2498 Ep 120 [STD] | Train: 0.2469 | Val: 0.2491 Ep 125 [STD] | Train: 0.2452 | Val: 0.2521 Ep 130 [STD] | Train: 0.2408 | Val: 0.2504 Ep 135 [STD] | Train: 0.2365 | Val: 0.2515 Ep 140 [STD] | Train: 0.2506 | Val: 0.2509 Ep 145 [STD] | Train: 0.2356 | Val: 0.2450 Ep 150 [STD] | Train: 0.2304 | Val: 0.2486 Ep 155 [STD] | Train: 0.2301 | Val: 0.2494 Ep 160 [STD] | Train: 0.2270 | Val: 0.2480 Ep 165 [STD] | Train: 0.2268 | Val: 0.2476 Ep 170 [STD] | Train: 0.2314 | Val: 0.2550 Ep 175 [STD] | Train: 0.2247 | Val: 0.2519 Ep 180 [STD] | Train: 0.2059 | Val: 0.2533 Ep 185 [STD] | Train: 0.2057 | Val: 0.2474 Ep 190 [STD] | Train: 0.2229 | Val: 0.2552 Ep 195 [STD] | Train: 0.2053 | Val: 0.2520 Ep 200 [STD] | Train: 0.2148 | Val: 0.2574 Ep 205 [STD] | Train: 0.2072 | Val: 0.2529 Ep 210 [STD] | Train: 0.2096 | Val: 0.2477 Stagnation Detected (Patience 100). Activating SWA. Ep 215 [SWA] | Train: 0.2010 | Val: 0.2515 Ep 220 [SWA] | Train: 0.2073 | Val: 0.2522 Ep 225 [SWA] | Train: 0.2085 | Val: 0.2508 Ep 230 [SWA] | Train: 0.1971 | Val: 0.2531 Ep 235 [SWA] | Train: 0.2049 | Val: 0.2537 Ep 240 [SWA] | Train: 0.1915 | Val: 0.2501 Ep 245 [SWA] | Train: 0.1790 | Val: 0.2558 Ep 250 [SWA] | Train: 0.1950 | Val: 0.2504 Ep 255 [SWA] | Train: 0.1838 | Val: 0.2571 Ep 260 [SWA] | Train: 0.1806 | Val: 0.2510 Ep 265 [SWA] | Train: 0.1949 | Val: 0.2492 Ep 270 [SWA] | Train: 0.1942 | Val: 0.2477 Ep 275 [SWA] | Train: 0.1859 | Val: 0.2481 Ep 280 [SWA] | Train: 0.1840 | Val: 0.2563 Ep 285 [SWA] | Train: 0.1765 | Val: 0.2496 Ep 290 [SWA] | Train: 0.1741 | Val: 0.2524 Ep 295 [SWA] | Train: 0.1721 | Val: 0.2519 Ep 300 [SWA] | Train: 0.1744 | Val: 0.2473 Ep 305 [SWA] | Train: 0.1691 | Val: 0.2512 Ep 310 [SWA] | Train: 0.1663 | Val: 0.2504 Early Stopping at Epoch 311
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
import json
from sklearn.metrics import confusion_matrix
# --- 1. CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Auto-detect best model in the Scratch Directory
if os.path.exists(os.path.join(SAVE_DIR, "swa_model_final.pth")):
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "swa_model_final.pth")
MODEL_TYPE = "SWA (Best Generalization)"
elif os.path.exists(os.path.join(SAVE_DIR, "best_model.pth")):
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "best_model.pth")
MODEL_TYPE = "Best Valid Loss"
else:
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint.pth")
MODEL_TYPE = "Latest Epoch"
print(f" EVALUATING SCRATCH MODEL: {MODEL_TYPE}")
print(f" Path: {CHECKPOINT_PATH}")
# --- 2. MODEL DEFINITION (Prithvi Scratch Architecture) ---
class PatchEmbed3D(nn.Module):
def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
super().__init__()
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, x):
return self.proj(x).flatten(2).transpose(1, 2)
class PrithviScratch(nn.Module):
def __init__(self):
super().__init__()
# Standard Prithvi Config
self.patch_embed = PatchEmbed3D(patch_size=16, frames=3, in_chans=6, embed_dim=768)
# Pos Embed (Fixed Size)
self.register_buffer("pos_embed", torch.zeros(1, 3 * (224//16)**2, 768))
# Encoder (12 Layers)
encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=768*4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)
# Decoder (U-Net Style)
in_dim = 768 * 3
self.decoder = nn.Sequential(
nn.Conv2d(in_dim, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(64, 1, 1)
)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
# Robust addition for Pos Embed
if self.pos_embed.shape[1] == x.shape[1]:
x = x + self.pos_embed
x = self.encoder(x)
# Reshape for Decoder
H_p = 14 # 224 / 16
x = x.transpose(1, 2).view(B, 768, 3, H_p, H_p)
x = x.reshape(B, -1, H_p, H_p)
return self.decoder(x)
# --- 3. HELPER FUNCTIONS ---
def plot_training_curves(ckpt_path):
if not os.path.exists(ckpt_path):
print(" History file not found.")
return
# Load history from checkpoint.pth (always contains full history)
hist_path = os.path.join(SAVE_DIR, "checkpoint.pth")
ckpt = torch.load(hist_path, map_location='cpu')
history = ckpt.get('history', None)
if history is None:
print(" No training history found.")
return
plt.figure(figsize=(12, 5))
plt.plot(history['train_loss'], 'b-', label='Training Loss', alpha=0.7)
plt.plot(history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
# Mark Min Val Loss
val_loss = history['val_loss']
if len(val_loss) > 0:
min_loss = min(val_loss)
min_epoch = val_loss.index(min_loss)
plt.scatter(min_epoch, min_loss, c='red', zorder=5)
plt.text(min_epoch, min_loss, f" Best: {min_loss:.4f}", verticalalignment='bottom')
plt.title('Prithvi Scratch: Training Progress')
plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.legend(); plt.grid(True, linestyle='--')
plt.show()
def visualize_predictions(model, loader, device, num_samples=3):
model.eval()
try:
x, y = next(iter(loader))
except StopIteration:
print(" Loader is empty.")
return
x, y = x.to(device), y.to(device)
with torch.no_grad():
preds = (torch.sigmoid(model(x)) > 0.5).float().cpu().numpy()
x_np, y_np = x.cpu().numpy(), y.cpu().numpy()
# SAFETY: Ensure we don't index out of bounds if batch size < num_samples
actual_samples = min(num_samples, len(x_np))
fig, axs = plt.subplots(actual_samples, 3, figsize=(15, 5 * actual_samples))
plt.suptitle("Prithvi Scratch Results: Input vs Truth vs Prediction", fontsize=16)
for i in range(actual_samples):
# Create FCC: NIR(3), Red(2), Green(1) from Time Step 1
img = np.stack([x_np[i,3,1], x_np[i,2,1], x_np[i,1,1]], axis=2)
# Normalize for Display
img = (img - img.min()) / (img.max() - img.min() + 1e-6)
# Handle axes if only 1 sample
ax = axs[i] if actual_samples > 1 else axs
ax[0].imshow(img)
ax[0].set_title("Satellite Input (NIR-R-G)")
ax[1].imshow(y_np[i,0], cmap='gray')
ax[1].set_title("Ground Truth")
ax[2].imshow(preds[i,0], cmap='gray')
ax[2].set_title("Prediction")
for a in ax: a.axis('off')
plt.tight_layout()
plt.show()
# --- 4. EXECUTION ---
# A. Plot Curves
plot_training_curves(CHECKPOINT_PATH)
# B. Load Model & Visualize
model = PrithviScratch().to(DEVICE)
print(f"Loading weights...")
if os.path.exists(CHECKPOINT_PATH):
ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
# Robust Key Fixer (Handles 'blocks' vs 'encoder.layers')
new_state_dict = {}
for k, v in state_dict.items():
k = k.replace('blocks', 'encoder.layers')
if 'pos_embed' in k and v.shape != model.pos_embed.shape: continue
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
print(" Weights Loaded Successfully.")
# C. Run Viz
if 'val_loader' in locals():
visualize_predictions(model, val_loader, DEVICE, num_samples=3)
else:
print(" 'val_loader' not found. Please run the Dataset Loading cell first.")
else:
print(f" Checkpoint not found at {CHECKPOINT_PATH}")
📊 EVALUATING SCRATCH MODEL: SWA (Best Generalization) 📂 Path: /content/drive/MyDrive/Prithvi_Scratch_Results_1/swa_model_final.pth
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Loading weights... ✅ Weights Loaded Successfully.
In [ ]:
import time
import torch
import numpy as np
import os
import json
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import confusion_matrix
from scipy.ndimage import binary_dilation
# --- 1. CONFIGURATION ---
# Using the path defined in Cell 7 (or manually set here)
if 'CHECKPOINT_PATH' not in locals():
# Fallback if Cell 7 wasn't run
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "swa_model_final.pth")
if not os.path.exists(CHECKPOINT_PATH):
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint.pth")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f" STARTING DEEP EVALUATION")
print(f" Model: {CHECKPOINT_PATH}")
# --- 2. METRIC FUNCTIONS ---
def compute_boundary_iou(pred_mask, gt_mask, dilation=2):
"""
Computes IoU specifically along the edges of the fields.
Crucial for agricultural segmentation where boundaries matter.
"""
# Get edges by dilating and subtracting original
pred_edges = binary_dilation(pred_mask, iterations=dilation) ^ pred_mask
gt_edges = binary_dilation(gt_mask, iterations=dilation) ^ gt_mask
intersection = (pred_edges & gt_edges).sum()
union = (pred_edges | gt_edges).sum()
if union == 0: return 1.0 # Perfect match (both empty edges)
return intersection / union
def compute_hausdorff_distance(pred_mask, gt_mask):
"""
Computes 95th Percentile Hausdorff Distance.
Measures the worst-case distance between prediction and ground truth contours.
Lower is better.
"""
if pred_mask.sum() == 0 or gt_mask.sum() == 0:
return 0.0 # Handle empty masks safely
# Get coordinates of all 'True' pixels
pred_coords = np.argwhere(pred_mask)
gt_coords = np.argwhere(gt_mask)
# Calculate directed distances
d_forward = directed_hausdorff(pred_coords, gt_coords)[0]
d_backward = directed_hausdorff(gt_coords, pred_coords)[0]
return max(d_forward, d_backward)
# --- 3. EVALUATION LOOP ---
def evaluate_model_depth(loader, model_path):
# A. Load Model
# Ensure PrithviScratch class is defined (from Cell 7)
model = PrithviScratch().to(DEVICE)
if not os.path.exists(model_path):
print(f" Error: Model file not found at {model_path}")
return None
ckpt = torch.load(model_path, map_location=DEVICE)
state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
# Auto-Fix Keys
new_state_dict = {}
for k, v in state_dict.items():
k = k.replace('blocks', 'encoder.layers')
if 'pos_embed' in k and v.shape != model.pos_embed.shape: continue
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
model.eval()
# B. Initialize Accumulators
tp_total, fp_total, fn_total, tn_total = 0, 0, 0, 0
boundary_ious = []
hausdorff_dists = []
inference_times = []
print(" Processing Validation Set...", end="")
with torch.no_grad():
for i, (x, y) in enumerate(loader):
x = x.to(DEVICE)
y_true = y.cpu().numpy() # (B, 1, H, W)
# 1. Inference Speed Test
start_time = time.time()
preds = model(x)
if torch.cuda.is_available(): torch.cuda.synchronize()
end_time = time.time()
inference_times.append((end_time - start_time) / x.shape[0]) # Time per image
# 2. Thresholding
probs = torch.sigmoid(preds)
y_pred = (probs > 0.5).float().cpu().numpy()
# 3. Global Pixel Metrics (Confusion Matrix)
# Flatten arrays to 1D for fast calculation
y_true_flat = y_true.flatten().astype(int)
y_pred_flat = y_pred.flatten().astype(int)
tn, fp, fn, tp = confusion_matrix(y_true_flat, y_pred_flat, labels=[0, 1]).ravel()
tp_total += tp
fp_total += fp
fn_total += fn
tn_total += tn
# 4. Shape Metrics (Per Image)
for b in range(x.shape[0]):
p_m = y_pred[b, 0].astype(bool)
g_m = y_true[b, 0].astype(bool)
# Boundary IoU
boundary_ious.append(compute_boundary_iou(p_m, g_m))
# Hausdorff (Only if both masks are non-empty)
if p_m.sum() > 0 and g_m.sum() > 0:
hausdorff_dists.append(compute_hausdorff_distance(p_m, g_m))
# Progress Bar
if i % 10 == 0: print(".", end="")
print("\n Evaluation Complete.")
# C. Calculate Final Scores
epsilon = 1e-6
# Pixel Scores
pixel_acc = (tp_total + tn_total) / (tp_total + tn_total + fp_total + fn_total + epsilon)
iou = tp_total / (tp_total + fp_total + fn_total + epsilon)
precision = tp_total / (tp_total + fp_total + epsilon)
recall = tp_total / (tp_total + fn_total + epsilon)
f1 = 2 * (precision * recall) / (precision + recall + epsilon)
# Shape Scores
avg_boundary_iou = np.mean(boundary_ious) if boundary_ious else 0.0
avg_hausdorff = np.mean(hausdorff_dists) if hausdorff_dists else 0.0
# Performance
fps = 1.0 / np.mean(inference_times)
# D. Format Results
results = {
"Model": "Prithvi Scratch",
"Path": model_path,
"Pixel_Accuracy": round(pixel_acc, 4),
"IoU (Jaccard)": round(iou, 4),
"F1-Score (Dice)": round(f1, 4),
"Precision": round(precision, 4),
"Recall": round(recall, 4),
"Boundary_IoU": round(avg_boundary_iou, 4),
"Hausdorff_Dist_px": round(avg_hausdorff, 2),
"Inference_Speed_FPS": round(fps, 2)
}
print("\n" + "="*40)
print(" FINAL PERFORMANCE REPORT")
print("="*40)
print(json.dumps(results, indent=4))
print("="*40)
return results
# --- 4. EXECUTION ---
if 'val_loader' in locals():
metrics = evaluate_model_depth(val_loader, CHECKPOINT_PATH)
else:
print(" Error: 'val_loader' is not defined. Please run the Dataset Loading cell first.")
STARTING DEEP EVALUATION Model: /content/drive/MyDrive/Prithvi_Scratch_Results_1/swa_model_final.pth
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Processing Validation Set....
Evaluation Complete.
========================================
FINAL PERFORMANCE REPORT
========================================
{
"Model": "Prithvi Scratch",
"Path": "/content/drive/MyDrive/Prithvi_Scratch_Results_1/swa_model_final.pth",
"Pixel_Accuracy": 0.7088,
"IoU (Jaccard)": 0.7088,
"F1-Score (Dice)": 0.8296,
"Precision": 0.7088,
"Recall": 1.0,
"Boundary_IoU": 0.0,
"Hausdorff_Dist_px": 20.62,
"Inference_Speed_FPS": 17.17
}
========================================
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
# --- CONFIGURATION ---
# Ensure this matches your setup
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint.pth") # or "swa_model_final.pth"
# --- FIXED VISUALIZATION FUNCTION ---
def visualize_robust(model, loader):
model.eval()
try:
x, y = next(iter(loader))
except StopIteration:
print(" Error: Validation loader is empty.")
return
x, y = x.to(DEVICE), y.to(DEVICE)
with torch.no_grad():
preds = torch.sigmoid(model(x))
masks = (preds > 0.5).float().cpu().numpy()
x_np = x.cpu().numpy()
y_np = y.cpu().numpy()
# --- CRITICAL FIX: DYNAMIC BATCH SIZE ---
# We take the minimum of 3 OR the actual batch size (e.g., 2)
num_samples = min(3, x_np.shape[0])
print(f" Visualizing {num_samples} samples (Batch Size: {x_np.shape[0]})")
# Create subplots dynamically based on actual samples
fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
plt.suptitle("RGB Input | Ground Truth | Prediction", fontsize=16)
for i in range(num_samples):
# Handle axes indexing (Matplotlib returns 1D array if rows=1, 2D if rows>1)
if num_samples == 1:
ax = axs
else:
ax = axs[i]
# 1. RGB Composite (Red, Green, Blue)
# Indices: 2=Red, 1=Green, 0=Blue. Time Step: 1 (Middle)
rgb = np.stack([x_np[i, 2, 1], x_np[i, 1, 1], x_np[i, 0, 1]], axis=2)
# Robust Normalization
p2, p98 = np.percentile(rgb, (2, 98))
rgb = np.clip((rgb - p2) / (p98 - p2 + 1e-6), 0, 1)
# 2. Plotting
ax[0].imshow(rgb)
ax[0].set_title(f"Sample {i+1} (RGB)")
ax[1].imshow(y_np[i, 0], cmap='gray')
ax[1].set_title("Ground Truth")
ax[2].imshow(masks[i, 0], cmap='gray')
ax[2].set_title(f"Prediction (Mean: {masks[i,0].mean():.2f})")
for a in ax: a.axis('off')
plt.tight_layout()
plt.show()
if 'model' in locals() and 'val_loader' in locals():
visualize_robust(model, val_loader)
else:
print(" Model or Loader missing. Please re-run the 'Recovery & Visualization' cell above.")
🖼️ Visualizing 2 samples (Batch Size: 2)