class ProgramDatabase:
"""
ClickHouse-backed database for storing and managing programs.
"""
def __init__(
self,
config: DatabaseConfig,
embedding_model: str = "text-embedding-3-small",
read_only: bool = False,
):
self.config = config
self.read_only = read_only
self.client = None
# Connect to ClickHouse
try:
self.client = clickhouse_connect.get_client(
host=self.config.host,
port=self.config.port,
username=self.config.username,
password=self.config.password,
database=self.config.database,
secure=self.config.secure,
connect_timeout=30,
send_receive_timeout=60,
)
logger.info(
f"Connected to ClickHouse at {self.config.host}:{self.config.port} (secure={self.config.secure})"
)
except Exception as e:
logger.error(f"Failed to connect to ClickHouse: {e}")
raise
if not read_only:
self.embedding_client = EmbeddingClient(model_name=embedding_model)
self._create_tables()
else:
self.embedding_client = None
self.last_iteration: int = 0
self.best_program_id: Optional[str] = None
self.beam_search_parent_id: Optional[str] = None
self._schedule_migration: bool = False
self._load_metadata()
# Initialize managers with ClickHouse client
self.island_manager = CombinedIslandManager(
client=self.client,
config=self.config,
)
count = self._count_programs()
logger.debug(f"DB initialized with {count} programs.")
def _create_tables(self):
# Programs table
self.client.command("""
CREATE TABLE IF NOT EXISTS programs (
id String,
code String,
language String,
parent_id String,
archive_inspiration_ids String, -- JSON
top_k_inspiration_ids String, -- JSON
generation Int32,
timestamp Float64,
code_diff String,
combined_score Float64,
public_metrics String, -- JSON
private_metrics String, -- JSON
text_feedback String,
complexity Float64,
embedding Array(Float32),
embedding_pca_2d String, -- JSON
embedding_pca_3d String, -- JSON
embedding_cluster_id Int32,
correct UInt8,
children_count Int32,
metadata String, -- JSON
island_idx Int32,
migration_history String, -- JSON
                in_archive UInt8 DEFAULT 0
                -- No version column: ReplacingMergeTree requires an integer or
                -- Date/DateTime version, so the Float64 timestamp cannot serve as
                -- one; without it, the most recently inserted row wins at merge time.
            ) ENGINE = ReplacingMergeTree()
            ORDER BY id
        """)
# Ensure 'thought' column exists (migration)
try:
self.client.command(
"ALTER TABLE programs ADD COLUMN IF NOT EXISTS thought String"
)
except Exception as e:
logger.warning(f"Could not add 'thought' column: {e}")
# Archive table (simplified, just tracks IDs in archive)
self.client.command("""
CREATE TABLE IF NOT EXISTS archive (
program_id String,
timestamp DateTime64(3) DEFAULT now()
) ENGINE = ReplacingMergeTree()
ORDER BY program_id
""")
# Metadata store
self.client.command("""
CREATE TABLE IF NOT EXISTS metadata_store (
key String,
value String,
timestamp DateTime64(3) DEFAULT now()
) ENGINE = ReplacingMergeTree(timestamp)
ORDER BY key
""")
logger.debug("ClickHouse tables ensured to exist.")
    def _count_programs(self) -> int:
        # uniqExact avoids over-counting rows duplicated by re-inserts before
        # the ReplacingMergeTree has merged them
        return self.client.command("SELECT uniqExact(id) FROM programs")
def _load_metadata(self):
try:
# Use query() with ORDER BY and LIMIT to get latest value from ReplacingMergeTree
last_iter_result = self.client.query(
"SELECT value FROM metadata_store WHERE key = 'last_iteration' ORDER BY timestamp DESC LIMIT 1"
)
if last_iter_result.result_rows:
self.last_iteration = int(last_iter_result.result_rows[0][0])
else:
self.last_iteration = 0
best_id_result = self.client.query(
"SELECT value FROM metadata_store WHERE key = 'best_program_id' ORDER BY timestamp DESC LIMIT 1"
)
if best_id_result.result_rows:
best_id = best_id_result.result_rows[0][0]
self.best_program_id = (
best_id if best_id and best_id != "None" else None
)
else:
self.best_program_id = None
beam_id_result = self.client.query(
"SELECT value FROM metadata_store WHERE key = 'beam_search_parent_id' ORDER BY timestamp DESC LIMIT 1"
)
if beam_id_result.result_rows:
beam_id = beam_id_result.result_rows[0][0]
self.beam_search_parent_id = (
beam_id if beam_id and beam_id != "None" else None
)
else:
self.beam_search_parent_id = None
except Exception as e:
logger.warning(f"Failed to load metadata: {e}")
def _update_metadata(self, key: str, value: Any):
if self.read_only:
return
val_str = str(value)
self.client.insert(
"metadata_store", [[key, val_str]], column_names=["key", "value"]
)
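    # Example: _update_metadata("last_iteration", 42) appends a new
    # ("last_iteration", "42", now()) row rather than updating in place;
    # _load_metadata then takes the newest row per key, so stale versions are
    # harmless until the ReplacingMergeTree merges them away.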
    # Column order shared by add() and _recompute_embeddings_and_clusters()
    PROGRAM_COLUMNS = [
        "id", "code", "language", "parent_id", "archive_inspiration_ids",
        "top_k_inspiration_ids", "generation", "timestamp", "code_diff",
        "combined_score", "public_metrics", "private_metrics", "text_feedback",
        "complexity", "embedding", "embedding_pca_2d", "embedding_pca_3d",
        "embedding_cluster_id", "correct", "children_count", "metadata",
        "island_idx", "migration_history", "in_archive", "thought",
    ]

    def _program_to_row(self, program: Program) -> list:
        """Serialize a Program into a row matching PROGRAM_COLUMNS."""
        return [
            program.id,
            program.code,
            program.language,
            program.parent_id or "",
            json.dumps(program.archive_inspiration_ids),
            json.dumps(program.top_k_inspiration_ids),
            program.generation,
            program.timestamp,
            program.code_diff or "",
            program.combined_score or 0.0,
            json.dumps(program.public_metrics),
            json.dumps(program.private_metrics),
            str(program.text_feedback) if program.text_feedback else "",
            program.complexity,
            program.embedding,
            json.dumps(program.embedding_pca_2d),
            json.dumps(program.embedding_pca_3d),
            # "or -1" would map a legitimate cluster id of 0 to -1, so test
            # for None explicitly
            program.embedding_cluster_id
            if program.embedding_cluster_id is not None
            else -1,
            1 if program.correct else 0,
            program.children_count,
            json.dumps(program.metadata),
            program.island_idx if program.island_idx is not None else -1,
            json.dumps(program.migration_history),
            1 if program.in_archive else 0,
            program.thought or "",
        ]

    def add(self, program: Program, verbose: bool = False) -> str:
if self.read_only:
raise PermissionError("Read-only mode")
self.island_manager.assign_island(program)
if program.complexity == 0.0:
try:
metrics = analyze_code_metrics(program.code, program.language)
program.complexity = metrics.get(
"complexity_score", float(len(program.code))
)
if not program.metadata:
program.metadata = {}
program.metadata["code_analysis_metrics"] = metrics
            except Exception as e:
                logger.warning(f"Code metrics analysis failed: {e}")
                program.complexity = float(len(program.code))
        # Serialize and insert a single row (columns shared via PROGRAM_COLUMNS)
        self.client.insert(
            "programs",
            [self._program_to_row(program)],
            column_names=self.PROGRAM_COLUMNS,
        )
        # Note: we deliberately do not bump the parent's children_count here.
        # In-place updates are awkward in ClickHouse; when an accurate count is
        # needed, derive it with SELECT count() FROM programs WHERE parent_id = ...
self._update_archive(program)
self._update_best_program(program)
self._recompute_embeddings_and_clusters()
if program.generation > self.last_iteration:
self.last_iteration = program.generation
self._update_metadata("last_iteration", self.last_iteration)
if verbose:
self._print_program_summary(program)
if self.island_manager.needs_island_copies(program):
self.island_manager.copy_program_to_islands(program)
            # The row just inserted still carries the flag, so drop it from the
            # metadata and re-write that column (an ALTER UPDATE in ClickHouse)
if program.metadata:
program.metadata.pop("_needs_island_copies", None)
self._update_program_metadata(program.id, program.metadata)
if self.island_manager.should_schedule_migration(program):
self._schedule_migration = True
self.check_scheduled_operations()
return program.id
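    # A minimal usage sketch for add() (illustrative values; the Program
    # constructor is assumed to accept these fields, which match the
    # serialization in _program_to_row):
    #
    #     prog = Program(id=uuid.uuid4().hex, code="print('hi')",
    #                    language="python", generation=0, timestamp=time.time())
    #     db.add(prog, verbose=True)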
    def _update_program_metadata(self, pid: str, metadata: dict):
        # Bind values instead of f-string interpolation: JSON routinely contains
        # quotes that would break the statement (and invite injection).
        # Note: ALTER TABLE ... UPDATE runs as an asynchronous mutation, so an
        # immediate read-back may still see the old metadata.
        self.client.command(
            "ALTER TABLE programs UPDATE metadata = %(meta)s WHERE id = %(pid)s",
            parameters={"meta": json.dumps(metadata), "pid": pid},
        )
    def get(self, program_id: str) -> Optional[Program]:
        try:
            # Bind the id, and take the newest version in case unmerged
            # ReplacingMergeTree duplicates exist
            result = self.client.query(
                "SELECT * FROM programs WHERE id = %(pid)s "
                "ORDER BY timestamp DESC LIMIT 1",
                parameters={"pid": program_id},
            )
if not result.result_rows:
return None
row = result.result_rows[0]
cols = result.column_names
data = dict(zip(cols, row))
return self._program_from_dict(data)
except Exception as e:
logger.error(f"Error getting program {program_id}: {e}")
return None
def _program_from_dict(self, data: Dict[str, Any]) -> Program:
# Deserialize JSON fields
for field in ["public_metrics", "private_metrics", "metadata"]:
if isinstance(data.get(field), str):
try:
data[field] = json.loads(data[field])
                except json.JSONDecodeError:
data[field] = {}
for field in [
"archive_inspiration_ids",
"top_k_inspiration_ids",
"embedding_pca_2d",
"embedding_pca_3d",
"migration_history",
]:
if isinstance(data.get(field), str):
try:
data[field] = json.loads(data[field])
                except json.JSONDecodeError:
data[field] = []
data["correct"] = bool(data.get("correct", 0))
data["in_archive"] = bool(data.get("in_archive", 0))
if "thought" not in data:
data["thought"] = ""
# Handle -1 defaults
if data.get("island_idx") == -1:
data["island_idx"] = None
if data.get("embedding_cluster_id") == -1:
data["embedding_cluster_id"] = None
if data.get("parent_id") == "":
data["parent_id"] = None
return Program.from_dict(data)
def _update_archive(self, program: Program):
if not self.config.archive_size or not program.correct:
return
        count = self.client.command("SELECT count() FROM archive")
        if count < self.config.archive_size:
            # insert() binds values safely; timestamp takes its DEFAULT now()
            self.client.insert(
                "archive", [[program.id]], column_names=["program_id"]
            )
else:
            # Find the worst-scoring entry (join archive with programs to get
            # each program's combined_score)
worst_res = self.client.query("""
SELECT a.program_id, p.combined_score
FROM archive a
LEFT JOIN programs p ON a.program_id = p.id
ORDER BY p.combined_score ASC LIMIT 1
""")
if worst_res.result_rows:
worst_id, worst_score = worst_res.result_rows[0]
                # Guard against a missing score; programs are stored with 0.0
                # when combined_score is None
                if (program.combined_score or 0.0) > worst_score:
                    self.client.command(
                        "ALTER TABLE archive DELETE WHERE program_id = %(pid)s",
                        parameters={"pid": worst_id},
                    )
                    self.client.insert(
                        "archive", [[program.id]], column_names=["program_id"]
                    )
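    # Worked example of the archive policy above: with archive_size=3 and
    # archived scores [0.2, 0.5, 0.9], a new correct program scoring 0.4 evicts
    # the 0.2 entry, while one scoring 0.1 is discarded.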
def _update_best_program(self, program: Program):
if not program.correct:
return
if not self.best_program_id:
self.best_program_id = program.id
self._update_metadata("best_program_id", program.id)
return
current_best = self.get(self.best_program_id)
if not current_best or (program.combined_score > current_best.combined_score):
self.best_program_id = program.id
self._update_metadata("best_program_id", program.id)
logger.info(
f"New best program: {program.id} (Score: {program.combined_score})"
)
def get_best_program(self, metric: Optional[str] = None) -> Optional[Program]:
query = "SELECT * FROM programs WHERE correct = 1"
if metric:
# This is tricky with JSON metrics in ClickHouse.
# We'd need JSON extraction functions.
# Assuming basic combined_score for now if metric is complex
logger.warning(
"Custom metric sorting in ClickHouse requires JSON extract logic. Using combined_score."
)
query += " ORDER BY combined_score DESC LIMIT 1"
res = self.client.query(query)
if res.result_rows:
return self._program_from_dict(
dict(zip(res.column_names, res.result_rows[0]))
)
return None
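    # A sketch of what metric-specific sorting could look like, assuming the
    # metric is stored as a key in the public_metrics JSON blob ('fitness' is a
    # hypothetical key):
    #
    #     SELECT * FROM programs FINAL WHERE correct = 1
    #     ORDER BY JSONExtractFloat(public_metrics, 'fitness') DESC
    #     LIMIT 1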
    def get_top_programs(
        self, n: int = 10, metric: str = "combined_score", correct_only: bool = False
    ) -> List[Program]:
        if metric != "combined_score":
            logger.warning(
                f"Metric '{metric}' sorting is not implemented for ClickHouse; using combined_score."
            )
        where = "WHERE correct = 1" if correct_only else ""
        query = (
            f"SELECT * FROM programs FINAL {where} "
            f"ORDER BY combined_score DESC LIMIT {int(n)}"
        )
res = self.client.query(query)
return [
self._program_from_dict(dict(zip(res.column_names, row)))
for row in res.result_rows
]
def get_programs_by_generation(self, generation: int) -> List[Program]:
"""Get all programs from a specific generation."""
query = f"SELECT * FROM programs WHERE generation = {generation} ORDER BY combined_score DESC"
res = self.client.query(query)
return [
self._program_from_dict(dict(zip(res.column_names, row)))
for row in res.result_rows
]
def _recompute_embeddings_and_clusters(self, num_clusters: int = 4):
try:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
except ImportError:
logger.warning("scikit-learn not installed, skipping clustering")
return
        # Fetch all programs with embeddings; FINAL collapses unmerged
        # duplicates so each program is clustered exactly once
        query = "SELECT * FROM programs FINAL WHERE length(embedding) > 0"
        res = self.client.query(query)
if not res.result_rows:
return
programs = [
self._program_from_dict(dict(zip(res.column_names, row)))
for row in res.result_rows
]
if len(programs) < 3:
return
embeddings = [p.embedding for p in programs]
X = np.array(embeddings)
# PCA 2D
pca_2d = PCA(n_components=2)
X_2d = pca_2d.fit_transform(X)
# PCA 3D
pca_3d = PCA(n_components=3)
X_3d = pca_3d.fit_transform(X)
# KMeans
kmeans = KMeans(n_clusters=min(num_clusters, len(X)), n_init=10)
labels = kmeans.fit_predict(X)
        # Re-serialize every program and bulk-insert; ReplacingMergeTree keeps
        # the most recently inserted row per id at merge time, so the refreshed
        # rows supersede the old ones
        updated_rows = []
        for i, program in enumerate(programs):
            program.embedding_pca_2d = X_2d[i].tolist()
            program.embedding_pca_3d = X_3d[i].tolist()
            program.embedding_cluster_id = int(labels[i])
            program.timestamp = time.time()
            updated_rows.append(self._program_to_row(program))
        self.client.insert(
            "programs", updated_rows, column_names=self.PROGRAM_COLUMNS
        )
logger.info(f"Recomputed embeddings and clusters for {len(programs)} programs")
def check_scheduled_operations(self):
if self._schedule_migration:
self.island_manager.perform_migration(self.last_iteration)
self._schedule_migration = False
def close(self):
if self.client:
self.client.close()
    def print_summary(self, console=None):
        # TODO: port the summary display logic to the ClickHouse backend
        pass

    def _print_program_summary(self, program):
        # TODO: port the per-program display logic
        pass
    # Other methods (sample, compute_similarity) still need the same ClickHouse
    # treatment; sample() is stubbed so callers do not crash in the meantime.
    def sample(self, *args, **kwargs):
        # Minimal stub: return the current best program as the parent, with no
        # archive or top-k inspirations
        parent = self.get_best_program()
        return parent, [], []
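
# A minimal end-to-end sketch (assuming DatabaseConfig carries the fields read
# in __init__: host, port, username, password, database, secure, archive_size):
#
#     config = DatabaseConfig(host="localhost", port=8123, username="default",
#                             password="", database="default", secure=False,
#                             archive_size=20)
#     db = ProgramDatabase(config, read_only=True)
#     best = db.get_best_program()
#     db.close()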