78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
"""
Champion Model Tracker
======================

Maintains the best-performing model across all autoresearch trials.
Saves the champion model and a manifest whenever a new best is found.
"""
|
|
|
|
import json
|
|
import os
|
|
import shutil
|
|
import time
|
|
from datetime import datetime
|
|
|
|
|
|
class ChampionTracker:
    """Track and save the best RL model found across all autoresearch trials.

    Champion metadata is persisted to ``<champion_dir>/manifest.json``; when the
    trial's model file is available, a copy is stored as ``<champion_dir>/model.zip``.
    """

    def __init__(self, champion_dir: str):
        """Create the champion directory if needed and load any existing manifest.

        Args:
            champion_dir: Directory holding ``manifest.json`` and ``model.zip``.
        """
        self.champion_dir = champion_dir
        self.manifest_path = os.path.join(champion_dir, 'manifest.json')
        os.makedirs(champion_dir, exist_ok=True)
        self._current_best = self._load_manifest()

    @staticmethod
    def _is_nan(value) -> bool:
        """Return True if *value* is a NaN number; non-numeric values return False."""
        import math
        try:
            return math.isnan(value)
        except TypeError:
            return False

    def _load_manifest(self) -> dict:
        """Load the existing champion manifest, or an empty "no champion" record.

        A missing manifest is normal (first run). An unreadable or malformed one is
        not fatal either, but a warning is printed instead of failing silently, so
        that losing the previous champion is visible.
        """
        empty = {'mean_reward': float('-inf'), 'trial': None}
        if not os.path.exists(self.manifest_path):
            return empty
        try:
            with open(self.manifest_path, encoding='utf-8') as f:
                manifest = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
            print(f'[Champion] WARNING: Could not read manifest {self.manifest_path}: {e}', flush=True)
            return empty
        if not isinstance(manifest, dict):
            print(f'[Champion] WARNING: Ignoring malformed manifest {self.manifest_path}', flush=True)
            return empty
        return manifest

    @property
    def best_reward(self) -> float:
        """Mean reward of the current champion, or ``-inf`` if there is none."""
        reward = self._current_best.get('mean_reward', float('-inf'))
        # A None/NaN reward (e.g. from a manifest written by an older version) would
        # make every comparison False or raise; treat it as "no champion yet".
        if reward is None or self._is_nan(reward):
            return float('-inf')
        return reward

    def _copy_model(self, model_path: str):
        """Copy *model_path* into the champion dir and return the path to record.

        Returns the champion copy on success. Otherwise it returns the original
        *model_path*, so the manifest never points at a stale ``model.zip`` left
        by an earlier champion. The copy goes to a temp file first and is then
        swapped in with ``os.replace``, so a failed copy cannot leave a truncated
        ``model.zip``.
        """
        if not model_path:
            return model_path
        if not os.path.exists(model_path):
            print(f'[Champion] WARNING: Model file not found: {model_path}', flush=True)
            return model_path

        champion_model_path = os.path.join(self.champion_dir, 'model.zip')
        tmp_path = champion_model_path + '.tmp'
        try:
            shutil.copy2(model_path, tmp_path)
            os.replace(tmp_path, champion_model_path)
        except OSError as e:
            print(f'[Champion] WARNING: Could not copy model: {e}', flush=True)
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the temp file may never have been created
            return model_path  # Fall back to original path
        return champion_model_path

    def _write_manifest(self, manifest: dict) -> None:
        """Write *manifest* atomically: a failed write keeps the previous manifest intact."""
        tmp_path = self.manifest_path + '.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(manifest, f, indent=2)
            os.replace(tmp_path, self.manifest_path)
        finally:
            # Only still present if something failed before os.replace.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def update_if_better(self, mean_reward: float, params: dict, model_path: str, trial: int) -> bool:
        """
        If mean_reward > current best, copy model to champion dir and update manifest.

        Args:
            mean_reward: Score of the trial. NaN is never accepted as a champion.
            params: Trial hyper-parameters; must be JSON-serializable.
            model_path: Path to the trial's saved model (may be empty/None).
            trial: Trial identifier stored in the manifest.

        Returns True if champion was updated.

        Raises:
            TypeError: If *params* is not JSON-serializable (the manifest on disk
                is left unchanged in that case).
        """
        if self._is_nan(mean_reward):
            print(f'[Champion] WARNING: Ignoring trial {trial}: mean_reward is NaN', flush=True)
            return False
        if mean_reward <= self.best_reward:
            return False

        champion_model_path = self._copy_model(model_path)

        manifest = {
            'trial': trial,
            'timestamp': datetime.now().isoformat(),
            'params': params,
            'mean_reward': mean_reward,
            'model_path': champion_model_path,
        }
        self._write_manifest(manifest)

        self._current_best = manifest
        print(f'[Champion] 🏆 NEW BEST! Trial {trial}: mean_reward={mean_reward:.4f} params={params}', flush=True)
        return True

    def summary(self) -> str:
        """Return a one-line human-readable description of the current champion."""
        best = self._current_best
        if best.get('trial') is None:
            return 'No champion yet.'
        return (
            f"Champion: trial={best['trial']} "
            f"mean_reward={self.best_reward:.4f} "
            f"params={best.get('params')}"
        )
|