"""Logit-guided pipeline: 2-agent chain with cross-model logit bias.

Researcher runs on model A (latent steps → extract hidden state → compute logits).
Solver runs on model B (generate with logit bias from mapped source distribution).

Unlike rosetta (single virtual token in KV-cache), logit-guided distributes
the source signal across the target's entire autoregressive generation.
"""

import time
from typing import Any, Dict, List

import torch
from transformers import LogitsProcessorList

from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt
from benchmarks.shared.metrics import gpu_memory_tracker
from .agents import AGENTS, build_latent_prompt
from .evaluate import extract_gold, extract_gsm8k_answer, check_correct


def _is_cuda_device(device: Any) -> bool:
    """Return True for "cuda", "cuda:N" and torch.device values that print that way."""
    return str(device).startswith("cuda")


def run_logit_guided_pipeline(
    conn_a: Any,
    model_a: Any,
    tokenizer_a: Any,
    identity_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    avp_map: Any,
    question: str,
    gold_solution: str,
    latent_steps: int = 10,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
    logit_bias_alpha: float = 0.5,
    logit_bias_confidence_threshold: float = 0.8,
) -> Dict:
    """Run the 2-agent cross-model pipeline with logit-guided decoding.

    Researcher (model A): latent steps → extract hidden state → compute logits
    Solver (model B): generate with mapped logit bias (no KV-cache priming)

    Args:
        conn_a: Connector for model A; must provide ``generate_latent_steps``.
        model_a: Source model; only its output embedding / ``lm_head`` is read.
        tokenizer_a: Tokenizer used to render and tokenize the researcher prompt.
        identity_a: Accepted for signature parity with the other pipelines;
            not used by this function.
        model_b: Target model that generates the answer.
        tokenizer_b: Tokenizer for model B.
        device: Device passed to tokenization, generation and memory tracking.
        avp_map: Projection map forwarded to ``compute_cross_model_logit_bias``.
        question: Question text.
        gold_solution: Reference solution from which the gold answer is extracted.
        latent_steps: Number of latent steps requested from ``conn_a``.
        max_new_tokens: Generation budget for the solver.
        temperature: Sampling temperature for the solver.
        top_p: Nucleus sampling threshold for the solver.
        verbose: Print per-agent diagnostics when True.
        logit_bias_alpha: ``alpha`` passed to ``CrossModelLogitBias``.
        logit_bias_confidence_threshold: ``confidence_threshold`` passed to
            ``CrossModelLogitBias``.

    Returns:
        A result dict with the prediction, correctness, timing, token counts,
        bias statistics, per-agent traces and ``"mode": "logit_guided"``.

    Raises:
        ValueError: If model A exposes neither output embeddings nor ``lm_head``.
    """
    from avp.rosetta.logit_guided import (
        CrossModelLogitBias,
        compute_cross_model_logit_bias,
    )

    with gpu_memory_tracker(device) as mem:
        t0 = time.perf_counter()
        agent_traces: List[Dict] = []
        total_prompt_tokens = 0
        total_latent_steps = 0
        total_output_tokens = 0

        researcher = AGENTS[0]
        solver = AGENTS[1]

        # --- Agent 1: Researcher on model A (latent steps) ---
        messages = build_latent_prompt(researcher.role, question)
        prompt_text = render_prompt(tokenizer_a, messages)
        input_ids, attention_mask = tokenize_prompt(tokenizer_a, prompt_text, device)

        agent_t0 = time.perf_counter()
        prompt_tokens = int(input_ids.shape[-1])
        total_prompt_tokens += prompt_tokens
        total_latent_steps += latent_steps

        # Run latent steps and collect hidden states
        past_kv, hidden_states = conn_a.generate_latent_steps(
            input_ids, latent_steps=latent_steps, attention_mask=attention_mask,
            collect_hidden_states=True,
        )

        # Use the last collected hidden state for logit computation.
        # Original author's note: resulting shape is [1, D_src].
        last_hidden = hidden_states[-1].unsqueeze(0)

        # Compute logit bias
        bias_t0 = time.perf_counter()

        source_lm_head = model_a.get_output_embeddings()
        if source_lm_head is None:
            source_lm_head = getattr(model_a, "lm_head", None)
        if source_lm_head is None:
            # Fail with an actionable message instead of an AttributeError on
            # ``None.weight`` a few lines below.
            raise ValueError(
                "model_a exposes neither output embeddings nor an 'lm_head'; "
                "cannot compute source logits for logit-guided decoding"
            )

        target_vocab_size = model_b.config.vocab_size
        bias = compute_cross_model_logit_bias(
            source_hidden_state=last_hidden,
            source_lm_head_weight=source_lm_head.weight,
            avp_map=avp_map,
            target_vocab_size=target_vocab_size,
        )

        bias_ms = (time.perf_counter() - bias_t0) * 1000

        # Build the non-zero mask once; it feeds both statistics below.
        nonzero_mask = bias != 0
        nonzero_count = int(nonzero_mask.sum())
        bias_magnitude = (
            float(bias[nonzero_mask].abs().mean()) if nonzero_count > 0 else 0.0
        )

        # Wire size: just the bias tensor (much smaller than text, larger than rosetta embed)
        wire_bytes = bias.nelement() * bias.element_size()

        agent_time_ms = (time.perf_counter() - agent_t0) * 1000

        agent_traces.append({
            "name": researcher.name,
            "role": researcher.role,
            "prompt_tokens": prompt_tokens,
            "latent_steps": latent_steps,
            "bias_compute_ms": bias_ms,
            "wire_bytes": wire_bytes,
            "agent_time_ms": agent_time_ms,
            "output": "",
        })

        if verbose:
            print(f" [{researcher.name}] latent steps={latent_steps}, "
                  f"bias compute={bias_ms:.1f}ms, "
                  f"nonzero={nonzero_count}/{target_vocab_size}, "
                  f"mean_abs_bias={bias_magnitude:.4f}")

        # Free model A KV-cache. ``last_hidden`` is dropped too because the
        # bias has already been computed from it.
        del past_kv, hidden_states, last_hidden
        if _is_cuda_device(device):
            torch.cuda.empty_cache()

        # --- Agent 2: Solver on model B (generate with logit bias) ---
        messages = build_latent_prompt(solver.role, question)
        prompt_text = render_prompt(tokenizer_b, messages)
        input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device)

        agent_t0 = time.perf_counter()
        prompt_tokens = int(input_ids.shape[-1])
        total_prompt_tokens += prompt_tokens

        # Create logit bias processor
        processor = CrossModelLogitBias(
            bias=bias,
            alpha=logit_bias_alpha,
            confidence_threshold=logit_bias_confidence_threshold,
        )

        text, _ = generate_text(
            model_b, tokenizer_b, input_ids, attention_mask, device,
            past_key_values=None,  # No KV-cache priming
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            logits_processor=LogitsProcessorList([processor]),
        )

        # Output length is measured by re-tokenizing the decoded text.
        output_encoded = tokenizer_b(text, add_special_tokens=False)
        output_tokens = len(output_encoded["input_ids"])
        total_output_tokens += output_tokens
        agent_time_ms = (time.perf_counter() - agent_t0) * 1000

        agent_traces.append({
            "name": solver.name,
            "role": solver.role,
            "prompt_tokens": prompt_tokens,
            "output_tokens": output_tokens,
            "agent_time_ms": agent_time_ms,
            "output": text,
        })

        if verbose:
            print(f" [{solver.name}] output ({len(text)} chars): {text[:200]}...")

        wall_time = time.perf_counter() - t0

    # Latent steps are counted as tokens for the throughput figure.
    total_tokens = total_prompt_tokens + total_latent_steps + total_output_tokens
    tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0

    gold = extract_gold(gold_solution)
    prediction = extract_gsm8k_answer(agent_traces[-1]["output"])
    correct = check_correct(prediction, gold)

    return {
        "question": question,
        "gold": gold,
        "prediction": prediction,
        "raw_output": agent_traces[-1]["output"],
        "correct": correct,
        "wall_time": wall_time,
        "total_prompt_tokens": total_prompt_tokens,
        "total_latent_steps": total_latent_steps,
        "total_output_tokens": total_output_tokens,
        "total_tokens": total_tokens,
        "tokens_per_sec": tokens_per_sec,
        "peak_memory_mb": mem["peak_memory_mb"],
        "bias_compute_ms": bias_ms,
        "logit_bias_alpha": logit_bias_alpha,
        "logit_bias_confidence_threshold": logit_bias_confidence_threshold,
        "bias_nonzero_count": nonzero_count,
        "bias_mean_magnitude": bias_magnitude,
        "wire_bytes": wire_bytes,
        "agents": agent_traces,
        "mode": "logit_guided",
    }


def run_logit_guided_benchmark(
    conn_a: Any,
    model_a: Any,
    tokenizer_a: Any,
    identity_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    avp_map: Any,
    dataset: List[Dict],
    latent_steps: int = 10,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
    logit_bias_alpha: float = 0.5,
    logit_bias_confidence_threshold: float = 0.8,
) -> List[Dict]:
    """Run logit-guided pipeline on a list of GSM8K samples.

    Each sample must provide ``"question"`` and ``"answer"`` keys. All other
    arguments are forwarded unchanged to ``run_logit_guided_pipeline``.

    Returns:
        One result dict per sample, in dataset order.
    """
    results: List[Dict] = []
    total = len(dataset)
    # Running tally so the progress line does not re-scan all earlier results.
    num_correct = 0

    for idx, sample in enumerate(dataset, start=1):
        if verbose:
            print(f"\n[Logit-Guided] Sample {idx}/{total}: "
                  f"{sample['question'][:80]}...")

        result = run_logit_guided_pipeline(
            conn_a, model_a, tokenizer_a, identity_a,
            model_b, tokenizer_b, device, avp_map,
            question=sample["question"],
            gold_solution=sample["answer"],
            latent_steps=latent_steps,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            verbose=verbose,
            logit_bias_alpha=logit_bias_alpha,
            logit_bias_confidence_threshold=logit_bias_confidence_threshold,
        )
        results.append(result)
        if result["correct"]:
            num_correct += 1

        if verbose:
            status = "CORRECT" if result["correct"] else "WRONG"
            print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, "
                  f"time={result['wall_time']:.1f}s)")
        else:
            print(f" [Logit-Guided a={logit_bias_alpha}] {idx}/{total} "
                  f"({num_correct}/{idx} correct, {result['wall_time']:.1f}s)",
                  flush=True)

    return results
"""Mid-layer injection pipeline: 2-agent chain with intermediate layer injection.

Researcher runs on model A (latent steps -> extract hidden state -> project).
Solver runs on model B (inject at layer ~75% depth via forward hook -> generate).

Unlike rosetta (injects projected embedding at layer 0 via inputs_embeds),
mid-layer injects at an intermediate layer, operating directly in the
semantic representation space.
"""

import time
from typing import Any, Dict, List, Optional

import torch

from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt
from benchmarks.shared.kv_utils import get_past_length
from benchmarks.shared.metrics import gpu_memory_tracker
from .agents import AGENTS, build_latent_prompt
from .evaluate import extract_gold, extract_gsm8k_answer, check_correct


def _is_cuda_device(device: Any) -> bool:
    """Return True for "cuda", "cuda:N" and torch.device values that print that way."""
    return str(device).startswith("cuda")


def _mean_metric(metrics: Any, key: str) -> Optional[float]:
    """Return ``metrics[key]`` reduced to a float, or None when the key is absent.

    Values exposing ``.mean()`` (e.g. tensors) are averaged first; plain
    numbers are converted directly.
    """
    if key not in metrics:
        return None
    value = metrics[key]
    if hasattr(value, "mean"):
        value = value.mean()
    return float(value)


def run_mid_layer_pipeline(
    conn_a: Any,
    model_a: Any,
    tokenizer_a: Any,
    identity_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    avp_map: Any,
    question: str,
    gold_solution: str,
    latent_steps: int = 10,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
    depth_ratio: float = 0.75,
) -> Dict:
    """Run the 2-agent cross-model pipeline with mid-layer injection.

    Researcher (model A): latent steps -> extract hidden state -> project
    Solver (model B): inject at intermediate layer via forward hook -> generate

    Args:
        conn_a: Connector for model A; must provide ``generate_latent_steps``
            and ``project_hidden_for_cross_model``.
        model_a: Accepted for signature parity; not used by this function.
        tokenizer_a: Tokenizer used for the researcher prompt.
        identity_a: Accepted for signature parity; not used by this function.
        model_b: Target model; its ``config.num_hidden_layers`` and ``dtype``
            are read and it performs generation.
        tokenizer_b: Tokenizer for model B.
        device: Device passed to tokenization, generation and memory tracking.
        avp_map: Projection map forwarded to ``project_hidden_for_cross_model``.
        question: Question text.
        gold_solution: Reference solution from which the gold answer is extracted.
        latent_steps: Number of latent steps requested from ``conn_a``.
        max_new_tokens: Generation budget for the solver.
        temperature: Sampling temperature for the solver.
        top_p: Nucleus sampling threshold for the solver.
        verbose: Print per-agent diagnostics when True.
        depth_ratio: Forwarded to ``compute_injection_layer`` together with the
            target layer count to choose the injection layer.

    Returns:
        A result dict with the prediction, correctness, timing, token counts,
        projection statistics, per-agent traces and ``"mode": "mid_layer"``.
    """
    from avp.rosetta.mid_layer import (
        compute_injection_layer,
        mid_layer_injection_hook,
    )

    with gpu_memory_tracker(device) as mem:
        t0 = time.perf_counter()
        agent_traces: List[Dict] = []
        total_prompt_tokens = 0
        total_latent_steps = 0
        total_output_tokens = 0

        researcher = AGENTS[0]
        solver = AGENTS[1]

        # --- Agent 1: Researcher on model A (latent steps) ---
        messages = build_latent_prompt(researcher.role, question)
        prompt_text = render_prompt(tokenizer_a, messages)
        input_ids, attention_mask = tokenize_prompt(tokenizer_a, prompt_text, device)

        agent_t0 = time.perf_counter()
        prompt_tokens = int(input_ids.shape[-1])
        total_prompt_tokens += prompt_tokens
        total_latent_steps += latent_steps

        # Collect hidden states from all latent steps
        past_kv, hidden_states = conn_a.generate_latent_steps(
            input_ids, latent_steps=latent_steps, attention_mask=attention_mask,
            collect_hidden_states=True,
        )

        # Use the last collected hidden state for projection.
        # Original author's note: resulting shape is [1, D_src].
        last_hidden = hidden_states[-1].unsqueeze(0)

        # Project to target model space
        proj_t0 = time.perf_counter()
        projected, proj_metrics = conn_a.project_hidden_for_cross_model(
            last_hidden, avp_map, return_metrics=True,
        )
        projection_ms = (time.perf_counter() - proj_t0) * 1000

        # Ensure [1, D] shape for injection
        if projected.dim() == 1:
            projected = projected.unsqueeze(0)
        if projected.dim() == 3:
            # Index the first batch entry explicitly: ``squeeze(0)`` is a
            # no-op when the leading dimension is not 1 and would leave a
            # 3-D tensor behind.
            projected = projected[0, -1:, :]

        wire_bytes = projected.nelement() * projected.element_size()
        agent_time_ms = (time.perf_counter() - agent_t0) * 1000

        # Compute injection layer
        target_num_layers = model_b.config.num_hidden_layers
        injection_layer = compute_injection_layer(target_num_layers, depth_ratio)

        agent_traces.append({
            "name": researcher.name,
            "role": researcher.role,
            "prompt_tokens": prompt_tokens,
            "latent_steps": latent_steps,
            "projection_ms": projection_ms,
            "wire_bytes": wire_bytes,
            "agent_time_ms": agent_time_ms,
            "injection_layer": injection_layer,
            "target_num_layers": target_num_layers,
            "output": "",
        })

        if verbose:
            print(f" [{researcher.name}] latent steps={latent_steps}, "
                  f"projection={projection_ms:.1f}ms, "
                  f"inject at layer {injection_layer}/{target_num_layers} "
                  f"({100*injection_layer/target_num_layers:.0f}% depth)")

        # Free model A KV-cache
        del past_kv, hidden_states, last_hidden
        if _is_cuda_device(device):
            torch.cuda.empty_cache()

        # --- Agent 2: Solver on model B (mid-layer injection + generate) ---
        messages = build_latent_prompt(solver.role, question)
        prompt_text = render_prompt(tokenizer_b, messages)
        input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device)

        agent_t0 = time.perf_counter()
        prompt_tokens = int(input_ids.shape[-1])
        total_prompt_tokens += prompt_tokens

        # Move and cast in a single call so only one copy is made.
        inject_hidden = projected.to(device=device, dtype=model_b.dtype)

        # Generation runs inside the hook's context manager.
        with mid_layer_injection_hook(model_b, injection_layer, inject_hidden):
            text, _ = generate_text(
                model_b, tokenizer_b, input_ids, attention_mask, device,
                past_key_values=None,  # No KV-cache priming
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )

        # Output length is measured by re-tokenizing the decoded text.
        output_encoded = tokenizer_b(text, add_special_tokens=False)
        output_tokens = len(output_encoded["input_ids"])
        total_output_tokens += output_tokens
        agent_time_ms = (time.perf_counter() - agent_t0) * 1000

        agent_traces.append({
            "name": solver.name,
            "role": solver.role,
            "prompt_tokens": prompt_tokens,
            "output_tokens": output_tokens,
            "agent_time_ms": agent_time_ms,
            "output": text,
        })

        if verbose:
            print(f" [{solver.name}] output ({len(text)} chars): {text[:200]}...")

        wall_time = time.perf_counter() - t0

    # Latent steps are counted as tokens for the throughput figure.
    total_tokens = total_prompt_tokens + total_latent_steps + total_output_tokens
    tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0

    gold = extract_gold(gold_solution)
    prediction = extract_gsm8k_answer(agent_traces[-1]["output"])
    correct = check_correct(prediction, gold)

    return {
        "question": question,
        "gold": gold,
        "prediction": prediction,
        "raw_output": agent_traces[-1]["output"],
        "correct": correct,
        "wall_time": wall_time,
        "total_prompt_tokens": total_prompt_tokens,
        "total_latent_steps": total_latent_steps,
        "total_output_tokens": total_output_tokens,
        "total_tokens": total_tokens,
        "tokens_per_sec": tokens_per_sec,
        "peak_memory_mb": mem["peak_memory_mb"],
        "projection_overhead_ms": projection_ms,
        "projection_wire_bytes": wire_bytes,
        "injection_layer": injection_layer,
        "depth_ratio": depth_ratio,
        "hidden_state_norm": _mean_metric(proj_metrics, "hidden_state_norm"),
        "nearest_cos_sim": _mean_metric(proj_metrics, "nearest_cos_sim"),
        "agents": agent_traces,
        "mode": "mid_layer",
    }


def run_mid_layer_benchmark(
    conn_a: Any,
    model_a: Any,
    tokenizer_a: Any,
    identity_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    avp_map: Any,
    dataset: List[Dict],
    latent_steps: int = 10,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
    depth_ratio: float = 0.75,
) -> List[Dict]:
    """Run mid-layer pipeline on a list of GSM8K samples.

    Each sample must provide ``"question"`` and ``"answer"`` keys. All other
    arguments are forwarded unchanged to ``run_mid_layer_pipeline``.

    Returns:
        One result dict per sample, in dataset order.
    """
    results: List[Dict] = []
    total = len(dataset)
    # Running tally so the progress line does not re-scan all earlier results.
    num_correct = 0

    for idx, sample in enumerate(dataset, start=1):
        if verbose:
            print(f"\n[Mid-Layer] Sample {idx}/{total}: "
                  f"{sample['question'][:80]}...")

        result = run_mid_layer_pipeline(
            conn_a, model_a, tokenizer_a, identity_a,
            model_b, tokenizer_b, device, avp_map,
            question=sample["question"],
            gold_solution=sample["answer"],
            latent_steps=latent_steps,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            verbose=verbose,
            depth_ratio=depth_ratio,
        )
        results.append(result)
        if result["correct"]:
            num_correct += 1

        if verbose:
            status = "CORRECT" if result["correct"] else "WRONG"
            print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, "
                  f"time={result['wall_time']:.1f}s)")
        else:
            print(f" [Mid-Layer d={depth_ratio}] {idx}/{total} "
                  f"({num_correct}/{idx} correct, {result['wall_time']:.1f}s)",
                  flush=True)

    return results
"""Trained projection pipeline: 2-agent chain with learned per-layer projections.

Researcher runs on model A (latent steps → extract hidden state).
Solver runs on model B (per-layer hooks inject trained projections → generate answer).

Requires a pre-trained AVPMap with layer_weights/layer_biases/layer_gates.
"""

import time
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F

from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt
from benchmarks.shared.kv_utils import get_past_length
from benchmarks.shared.metrics import gpu_memory_tracker
from .agents import AGENTS, build_latent_prompt
from .evaluate import extract_gold, extract_gsm8k_answer, check_correct


def _is_cuda_device(device: Any) -> bool:
    """Return True for "cuda", "cuda:N" and torch.device values that print that way."""
    return str(device).startswith("cuda")


def run_trained_pipeline(
    conn_a: Any,
    model_a: Any,
    tokenizer_a: Any,
    identity_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    avp_map: Any,
    question: str,
    gold_solution: str,
    latent_steps: int = 10,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
) -> Dict:
    """Run the 2-agent pipeline with trained per-layer projection.

    Researcher (model A): latent steps → extract hidden state
    Solver (model B): per-layer hooks inject trained projections → generate text

    Args:
        conn_a: Connector for model A; must provide ``generate_latent_steps``.
        model_a: Source model; run once on a single token to read its final
            hidden state.
        tokenizer_a: Tokenizer for model A; also supplies the EOS token id.
        identity_a: Accepted for signature parity; not used by this function.
        model_b: Target model that generates the answer.
        tokenizer_b: Tokenizer for model B.
        device: Device passed to tokenization, generation and memory tracking.
        avp_map: Map exposing ``layer_weights``, ``layer_biases`` and
            ``layer_gates`` sequences, iterated in lockstep.
        question: Question text.
        gold_solution: Reference solution from which the gold answer is extracted.
        latent_steps: Number of latent steps requested from ``conn_a``.
        max_new_tokens: Generation budget for the solver.
        temperature: Sampling temperature for the solver.
        top_p: Nucleus sampling threshold for the solver.
        verbose: Print per-agent diagnostics when True.

    Returns:
        A result dict with the prediction, correctness, timing, token counts,
        active-layer statistics, per-agent traces and ``"mode": "trained"``.
    """
    from avp.rosetta.trained_hooks import trained_multi_layer_hook

    with gpu_memory_tracker(device) as mem:
        t0 = time.perf_counter()
        agent_traces: List[Dict] = []
        total_prompt_tokens = 0
        total_latent_steps = 0
        total_output_tokens = 0

        researcher = AGENTS[0]
        solver = AGENTS[1]

        # --- Agent 1: Researcher on model A (latent steps) ---
        messages = build_latent_prompt(researcher.role, question)
        prompt_text = render_prompt(tokenizer_a, messages)
        input_ids, attention_mask = tokenize_prompt(tokenizer_a, prompt_text, device)

        agent_t0 = time.perf_counter()
        prompt_tokens = int(input_ids.shape[-1])
        total_prompt_tokens += prompt_tokens
        total_latent_steps += latent_steps

        past_kv = conn_a.generate_latent_steps(
            input_ids, latent_steps=latent_steps, attention_mask=attention_mask,
        )

        # Extract last hidden state: feed one extra token (EOS, or id 0 when
        # the tokenizer defines none) on top of the latent-step cache and read
        # the final layer's output at that position.
        proj_t0 = time.perf_counter()
        past_len = get_past_length(past_kv)
        dummy_mask = torch.ones((1, past_len + 1), dtype=torch.long, device=device)
        eos_id = tokenizer_a.eos_token_id or 0
        dummy_ids = torch.tensor([[eos_id]], device=device)
        with torch.no_grad():
            out = model_a(
                input_ids=dummy_ids,
                attention_mask=dummy_mask,
                past_key_values=past_kv,
                output_hidden_states=True,
                return_dict=True,
            )
        # Original author's note: resulting shape is [1, D_src].
        src_hidden = out.hidden_states[-1][:, -1, :].float()

        # Pre-compute per-layer projections. Entries are None for layers whose
        # gate is below 0.01, otherwise a (projection, gate) pair.
        layer_projections = []
        active_count = 0
        for w, b, gate in zip(
            avp_map.layer_weights, avp_map.layer_biases, avp_map.layer_gates
        ):
            if gate < 0.01:
                layer_projections.append(None)
                continue
            # ``src_hidden`` was cast to float32 above; cast the parameters to
            # match so F.linear does not fail on half-precision weights.
            projected = F.linear(
                src_hidden,
                w.to(device=device, dtype=src_hidden.dtype),
                b.to(device=device, dtype=src_hidden.dtype),
            )
            layer_projections.append((projected, gate))
            active_count += 1

        projection_ms = (time.perf_counter() - proj_t0) * 1000

        agent_time_ms = (time.perf_counter() - agent_t0) * 1000
        agent_traces.append({
            "name": researcher.name,
            "role": researcher.role,
            "prompt_tokens": prompt_tokens,
            "latent_steps": latent_steps,
            "projection_ms": projection_ms,
            "active_layers": active_count,
            "total_layers": len(avp_map.layer_gates),
            "agent_time_ms": agent_time_ms,
            "output": "",
        })

        if verbose:
            print(f" [{researcher.name}] latent steps={latent_steps}, "
                  f"projection={projection_ms:.1f}ms, "
                  f"active layers={active_count}/{len(avp_map.layer_gates)}")

        # Free model A KV-cache. ``out`` must be dropped as well: the forward
        # output keeps its own references to model A tensors (all-layer hidden
        # states, and presumably the cache), which would otherwise stay
        # allocated while model B generates.
        del past_kv, out
        if _is_cuda_device(device):
            torch.cuda.empty_cache()

        # --- Agent 2: Solver on model B (trained hooks, generate) ---
        messages = build_latent_prompt(solver.role, question)
        prompt_text = render_prompt(tokenizer_b, messages)
        input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device)

        agent_t0 = time.perf_counter()
        prompt_tokens = int(input_ids.shape[-1])
        total_prompt_tokens += prompt_tokens

        # Generation runs inside the hook's context manager.
        with trained_multi_layer_hook(model_b, layer_projections):
            text, _ = generate_text(
                model_b, tokenizer_b, input_ids, attention_mask, device,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )

        # Output length is measured by re-tokenizing the decoded text.
        output_encoded = tokenizer_b(text, add_special_tokens=False)
        output_tokens = len(output_encoded["input_ids"])
        total_output_tokens += output_tokens
        agent_time_ms = (time.perf_counter() - agent_t0) * 1000

        agent_traces.append({
            "name": solver.name,
            "role": solver.role,
            "prompt_tokens": prompt_tokens,
            "output_tokens": output_tokens,
            "agent_time_ms": agent_time_ms,
            "output": text,
        })

        if verbose:
            print(f" [{solver.name}] output ({len(text)} chars): {text[:200]}...")

        wall_time = time.perf_counter() - t0

    # Latent steps are counted as tokens for the throughput figure.
    total_tokens = total_prompt_tokens + total_latent_steps + total_output_tokens
    tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0

    gold = extract_gold(gold_solution)
    prediction = extract_gsm8k_answer(agent_traces[-1]["output"])
    correct = check_correct(prediction, gold)

    return {
        "question": question,
        "gold": gold,
        "prediction": prediction,
        "raw_output": agent_traces[-1]["output"],
        "correct": correct,
        "wall_time": wall_time,
        "total_prompt_tokens": total_prompt_tokens,
        "total_latent_steps": total_latent_steps,
        "total_output_tokens": total_output_tokens,
        "total_tokens": total_tokens,
        "tokens_per_sec": tokens_per_sec,
        "peak_memory_mb": mem["peak_memory_mb"],
        "projection_overhead_ms": projection_ms,
        "active_layers": active_count,
        "total_layers": len(avp_map.layer_gates),
        "agents": agent_traces,
        "mode": "trained",
    }


def run_trained_benchmark(
    conn_a: Any,
    model_a: Any,
    tokenizer_a: Any,
    identity_a: Any,
    model_b: Any,
    tokenizer_b: Any,
    device: str,
    avp_map: Any,
    dataset: List[Dict],
    latent_steps: int = 10,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
    verbose: bool = False,
) -> List[Dict]:
    """Run trained projection pipeline on a list of GSM8K samples.

    Each sample must provide ``"question"`` and ``"answer"`` keys. All other
    arguments are forwarded unchanged to ``run_trained_pipeline``.

    Returns:
        One result dict per sample, in dataset order.
    """
    results: List[Dict] = []
    total = len(dataset)
    # Running tally so the progress line does not re-scan all earlier results.
    num_correct = 0

    for idx, sample in enumerate(dataset, start=1):
        if verbose:
            print(f"\n[Trained] Sample {idx}/{total}: {sample['question'][:80]}...")

        result = run_trained_pipeline(
            conn_a, model_a, tokenizer_a, identity_a,
            model_b, tokenizer_b, device, avp_map,
            question=sample["question"],
            gold_solution=sample["answer"],
            latent_steps=latent_steps,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            verbose=verbose,
        )
        results.append(result)
        if result["correct"]:
            num_correct += 1

        if verbose:
            status = "CORRECT" if result["correct"] else "WRONG"
            print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, "
                  f"time={result['wall_time']:.1f}s)")
        else:
            print(f" [Trained] {idx}/{total} "
                  f"({num_correct}/{idx} correct, {result['wall_time']:.1f}s)",
                  flush=True)

    return results
-104,6 +109,8 @@ def run_benchmark(config: dict) -> dict: output_dir = config.get("output_dir") projection_temperature = config.get("projection_temperature", 1.0) num_transfer_states = config.get("num_transfer_states", 1) + logit_bias_alpha = config.get("logit_bias_alpha", 0.5) + logit_bias_confidence_threshold = config.get("logit_bias_confidence_threshold", 0.8) model_b_name = config.get("model_b", "Qwen/Qwen2.5-0.5B-Instruct") @@ -111,12 +118,15 @@ def run_benchmark(config: dict) -> dict: run_latent = mode in ("latent", "both", "all") run_text = mode in ("text", "both", "all") run_rosetta = mode in ("rosetta", "all") + run_logit_guided = mode in ("logit_guided", "all") + run_mid_layer = mode in ("mid_layer", "all") + run_trained = mode in ("trained",) # not in "all" — requires training run_text_cross_model = mode in ("text_cross_model", "all") print(f"Device: {device}") print(f"Mode: {mode}") print(f"Model A: {model_name}") - if run_rosetta or run_text_cross_model: + if run_rosetta or run_logit_guided or run_mid_layer or run_trained or run_text_cross_model: print(f"Model B: {model_b_name}") print(f"Samples: {max_samples}") print(f"Latent steps: {latent_steps}") @@ -124,7 +134,9 @@ def run_benchmark(config: dict) -> dict: print(f"Temperature: {temperature}") print(f"Seed: {seed}") print(f"Pipelines: direct={run_direct}, text={run_text}, latent={run_latent}, " - f"rosetta={run_rosetta}, text_cross_model={run_text_cross_model}") + f"rosetta={run_rosetta}, logit_guided={run_logit_guided}, " + f"mid_layer={run_mid_layer}, trained={run_trained}, " + f"text_cross_model={run_text_cross_model}") print() dataset = load_dataset(max_samples) @@ -134,6 +146,9 @@ def run_benchmark(config: dict) -> dict: latent_results = None text_results = None rosetta_results = None + logit_guided_results = None + mid_layer_results = None + trained_results = None text_cross_model_results = None if run_direct: @@ -182,7 +197,7 @@ def run_benchmark(config: dict) -> dict: # Load model B if 
needed for cross-model modes model_b = tokenizer_b = connector_b = identity_b = None - if run_rosetta or run_text_cross_model: + if run_rosetta or run_logit_guided or run_mid_layer or run_trained or run_text_cross_model: model_b, tokenizer_b, connector_b, identity_b = load_model(model_b_name, device) if run_text_cross_model: @@ -235,6 +250,117 @@ def run_benchmark(config: dict) -> dict: num_transfer_states=num_transfer_states, ) + if run_logit_guided: + from benchmarks.gsm8k_2agent.pipeline_logit_guided import run_logit_guided_benchmark + from avp.rosetta.calibrate import calibrate + + print("\n" + "=" * 50) + print("Running LOGIT-GUIDED (cross-model logit bias) pipeline...") + print(f" Model A (Researcher): {model_name}") + print(f" Model B (Solver): {model_b_name}") + print(f" Alpha: {logit_bias_alpha}") + print(f" Confidence threshold: {logit_bias_confidence_threshold}") + print("=" * 50) + set_seed(seed) + + # Calibrate (reuse if already done for rosetta) + if 'avp_map' not in dir() or avp_map is None: + print("Calibrating Rosetta Stone projection...") + avp_map = calibrate( + source_model=model, target_model=model_b, + source_tokenizer=tokenizer, target_tokenizer=tokenizer_b, + device=device, + ) + print(f" Method: {avp_map.method.value}, " + f"validation_score: {avp_map.validation_score:.4f}, " + f"{avp_map.source_dim}d → {avp_map.target_dim}d") + + logit_guided_results = run_logit_guided_benchmark( + conn_a=connector, model_a=model, tokenizer_a=tokenizer, + identity_a=identity, model_b=model_b, tokenizer_b=tokenizer_b, + device=device, avp_map=avp_map, dataset=dataset, + latent_steps=latent_steps, max_new_tokens=max_new_tokens, + temperature=temperature, top_p=top_p, verbose=verbose, + logit_bias_alpha=logit_bias_alpha, + logit_bias_confidence_threshold=logit_bias_confidence_threshold, + ) + + if run_mid_layer: + from benchmarks.gsm8k_2agent.pipeline_mid_layer import run_mid_layer_benchmark + from avp.rosetta.calibrate import calibrate + + print("\n" + "=" * 
50) + print("Running MID-LAYER (cross-model mid-layer injection) pipeline...") + print(f" Model A (Researcher): {model_name}") + print(f" Model B (Solver): {model_b_name}") + print(f" Depth ratio: 0.75") + print("=" * 50) + set_seed(seed) + + # Calibrate (reuse if already done) + if 'avp_map' not in dir() or avp_map is None: + print("Calibrating Rosetta Stone projection...") + avp_map = calibrate( + source_model=model, target_model=model_b, + source_tokenizer=tokenizer, target_tokenizer=tokenizer_b, + device=device, + ) + print(f" Method: {avp_map.method.value}, " + f"validation_score: {avp_map.validation_score:.4f}, " + f"{avp_map.source_dim}d -> {avp_map.target_dim}d") + + mid_layer_results = run_mid_layer_benchmark( + conn_a=connector, model_a=model, tokenizer_a=tokenizer, + identity_a=identity, model_b=model_b, tokenizer_b=tokenizer_b, + device=device, avp_map=avp_map, dataset=dataset, + latent_steps=latent_steps, max_new_tokens=max_new_tokens, + temperature=temperature, top_p=top_p, verbose=verbose, + ) + + if run_trained: + from benchmarks.gsm8k_2agent.pipeline_trained import run_trained_benchmark + from avp.rosetta.train import train_projector, TrainConfig + + print("\n" + "=" * 50) + print("Running TRAINED (per-layer learned projection) pipeline...") + print(f" Model A (Researcher): {model_name}") + print(f" Model B (Solver): {model_b_name}") + print("=" * 50) + + # Training phase + train_config = TrainConfig( + num_samples=config.get("train_samples", 2000), + batch_size=config.get("train_batch_size", 4), + num_epochs=config.get("train_epochs", 2), + learning_rate=config.get("train_lr", 1e-4), + gate_init=config.get("train_gate_init", -5.0), + gate_reg_weight=config.get("train_gate_reg", 0.01), + ) + print(f"Training projector: {train_config.num_samples} samples, " + f"{train_config.num_epochs} epochs...") + + trained_map = train_projector( + source_model=model, + target_model=model_b, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer_b, + 
device=device, + config=train_config, + ) + active = [i for i, g in enumerate(trained_map.layer_gates) if g > 0.01] + print(f"Training complete. Active layers: {len(active)}/{len(trained_map.layer_gates)}") + print(f"Validation score: {trained_map.validation_score:.4f}") + + set_seed(seed) + + trained_results = run_trained_benchmark( + conn_a=connector, model_a=model, tokenizer_a=tokenizer, + identity_a=identity, model_b=model_b, tokenizer_b=tokenizer_b, + device=device, avp_map=trained_map, dataset=dataset, + latent_steps=latent_steps, max_new_tokens=max_new_tokens, + temperature=temperature, top_p=top_p, verbose=verbose, + ) + # Free model B to reclaim GPU memory if model_b is not None: del model_b, tokenizer_b, connector_b, identity_b @@ -254,6 +380,12 @@ def run_benchmark(config: dict) -> dict: modes.append(("Text", 13, text_results)) if rosetta_results is not None: modes.append(("Rosetta", 13, rosetta_results)) + if logit_guided_results is not None: + modes.append(("Logit-Guided", 13, logit_guided_results)) + if mid_layer_results is not None: + modes.append(("Mid-Layer", 13, mid_layer_results)) + if trained_results is not None: + modes.append(("Trained", 13, trained_results)) if text_cross_model_results is not None: modes.append(("Text Cross-Model", 16, text_cross_model_results)) @@ -267,6 +399,12 @@ def run_benchmark(config: dict) -> dict: available["latent"] = latent_results if rosetta_results is not None: available["rosetta"] = rosetta_results + if logit_guided_results is not None: + available["logit_guided"] = logit_guided_results + if mid_layer_results is not None: + available["mid_layer"] = mid_layer_results + if trained_results is not None: + available["trained"] = trained_results if text_cross_model_results is not None: available["text_cross_model"] = text_cross_model_results agreement_data = compute_agreement(available) if len(available) > 1 else None @@ -289,7 +427,7 @@ def run_benchmark(config: dict) -> dict: "config": { "benchmark": "gsm8k_2agent", 
"model_a": model_name, - "model_b": model_b_name if (run_rosetta or run_text_cross_model) else None, + "model_b": model_b_name if (run_rosetta or run_logit_guided or run_mid_layer or run_trained or run_text_cross_model) else None, "device": device, "mode": mode, "max_samples": max_samples, @@ -320,6 +458,21 @@ def run_benchmark(config: dict) -> dict: "summary": compute_stats(rosetta_results), "samples": rosetta_results, } + if logit_guided_results is not None: + output_data["logit_guided"] = { + "summary": compute_stats(logit_guided_results), + "samples": logit_guided_results, + } + if mid_layer_results is not None: + output_data["mid_layer"] = { + "summary": compute_stats(mid_layer_results), + "samples": mid_layer_results, + } + if trained_results is not None: + output_data["trained"] = { + "summary": compute_stats(trained_results), + "samples": trained_results, + } if text_cross_model_results is not None: output_data["text_cross_model"] = { "summary": compute_stats(text_cross_model_results), diff --git a/benchmarks/hotpotqa/agents.py b/benchmarks/hotpotqa/agents.py index 4aa298e..7419392 100644 --- a/benchmarks/hotpotqa/agents.py +++ b/benchmarks/hotpotqa/agents.py @@ -63,10 +63,12 @@ def build_latent_prompt( role: str, question: str, paragraphs_text: str = "", + key_context: str = "", ) -> List[Dict[str, str]]: """Build chat messages for latent mode agents. In latent mode, prior context is carried via KV-cache. + key_context: optional key tokens extracted by attention for hybrid mode. """ if role == "finder": user_prompt = ( @@ -80,15 +82,28 @@ def build_latent_prompt( elif role == "answerer": # In latent mode, the Answerer gets the question but NOT the paragraphs. # The Finder's KV-cache carries the model's understanding of the context. - user_prompt = ( - f"You are an Answerer Agent. A Finder agent has already analyzed " - f"the relevant paragraphs for you. 
Use their analysis to answer " - f"the question.\n\n" - f"## Question: {question}\n\n" - f"Based on the analysis provided, give a short, precise answer " - f"(a few words at most).\n\n" - f"Answer:" - ) + if key_context: + # Hybrid mode: include key tokens as context + user_prompt = ( + f"You are an Answerer Agent. A Finder agent has already analyzed " + f"the relevant paragraphs for you. Use their analysis and the " + f"key context below to answer the question.\n\n" + f"## Key Context: {key_context}\n\n" + f"## Question: {question}\n\n" + f"Based on the analysis and context provided, give a short, precise answer " + f"(a few words at most).\n\n" + f"Answer:" + ) + else: + user_prompt = ( + f"You are an Answerer Agent. A Finder agent has already analyzed " + f"the relevant paragraphs for you. Use their analysis to answer " + f"the question.\n\n" + f"## Question: {question}\n\n" + f"Based on the analysis provided, give a short, precise answer " + f"(a few words at most).\n\n" + f"Answer:" + ) elif role == "decomposer": user_prompt = ( f"You are a Decomposer Agent. Break the following multi-hop question " diff --git a/benchmarks/hotpotqa/pipeline_rosetta.py b/benchmarks/hotpotqa/pipeline_rosetta.py index da4c3f3..1514cbe 100644 --- a/benchmarks/hotpotqa/pipeline_rosetta.py +++ b/benchmarks/hotpotqa/pipeline_rosetta.py @@ -18,6 +18,67 @@ from .evaluate import exact_match, extract_answer, token_f1 +def find_content_start(input_ids: Any, tokenizer: Any) -> int: + """Find the token position where paragraph content begins. + + Searches for "## Paragraphs:" or "Paragraphs:" marker in the tokenized prompt. + Falls back to skipping the first 20% of tokens (covers system + instruction). 
+ """ + # Decode to find the marker position in text, then map back to tokens + full_text = tokenizer.decode(input_ids[0], skip_special_tokens=False) + for marker in ["## Paragraphs:", "Paragraphs:", "## paragraphs:"]: + marker_pos = full_text.find(marker) + if marker_pos >= 0: + # Count tokens up to the marker (+ marker itself) + prefix = full_text[:marker_pos + len(marker)] + prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) + return len(prefix_tokens) + + # Fallback: skip first 20% of tokens (system + instruction) + return int(input_ids.shape[-1] * 0.2) + + +def extract_key_tokens( + attention_weights: Any, + input_ids: Any, + tokenizer: Any, + prompt_tokens: int, + k: int = 64, +) -> str: + """Extract top-K important tokens from attention weights. + + Uses attention from the final token (over the full KV-cache) to score + input token importance. Skips system prompt and instruction tokens to + avoid attention sinks. Returns decoded text of the top-K tokens. + + Args: + attention_weights: Last layer attention, shape [batch, heads, 1, seq_len] + input_ids: Original input token IDs, shape [1, prompt_tokens] + tokenizer: Source model tokenizer + prompt_tokens: Number of original prompt tokens (before latent steps) + k: Number of tokens to extract + """ + # Average across heads: [seq_len] + attn_scores = attention_weights[0, :, -1, :].mean(dim=0) + + # Only score content tokens (skip system prompt + instruction to avoid attention sinks) + content_start = find_content_start(input_ids, tokenizer) + prompt_scores = attn_scores[:prompt_tokens].clone() + prompt_scores[:content_start] = -float("inf") # Zero out template tokens + + # Select top-K by attention (capped at available content tokens) + content_tokens = prompt_tokens - content_start + k = min(k, max(content_tokens, 1)) + topk_indices = prompt_scores.topk(k).indices + # Sort by position to maintain reading order + topk_indices = topk_indices.sort().values + + # Extract and decode + key_ids = 
input_ids[0, topk_indices] + key_text = tokenizer.decode(key_ids, skip_special_tokens=True) + return key_text + + def run_rosetta_pipeline( conn_a: Any, model_a: Any, @@ -37,6 +98,7 @@ def run_rosetta_pipeline( verbose: bool = False, projection_temperature: float = 1.0, num_transfer_states: int = 1, + hybrid_k: int = 0, ) -> Dict: """Run the 2-agent cross-model pipeline on a single HotpotQA problem. @@ -66,6 +128,8 @@ def run_rosetta_pipeline( total_latent_steps += latent_steps attention_entropy = None + key_text = None + hybrid_text_tokens = 0 if num_transfer_states > 1: # Multi-embedding: collect hidden states from all latent steps past_kv, hidden_states = conn_a.generate_latent_steps( @@ -95,6 +159,7 @@ def run_rosetta_pipeline( dummy_mask = torch.ones((1, past_len + 1), dtype=torch.long, device=device) eos_id = tokenizer_a.eos_token_id or 0 dummy_ids = torch.tensor([[eos_id]], device=device) + with torch.no_grad(): out = model_a( input_ids=dummy_ids, @@ -104,6 +169,7 @@ def run_rosetta_pipeline( output_attentions=True, return_dict=True, ) + last_hidden = out.hidden_states[-1][:, -1, :] # [1, D_src] # Compute attention entropy from last layer @@ -113,6 +179,18 @@ def run_rosetta_pipeline( attn_ent = -(last_attn * attn_log).sum(dim=-1) # [batch, heads] attention_entropy = float(attn_ent.mean()) + # Extract key tokens for hybrid mode + key_text = None + if hybrid_k > 0 and out.attentions: + key_text = extract_key_tokens( + out.attentions[-1], input_ids, tokenizer_a, + prompt_tokens=prompt_tokens, k=hybrid_k, + ) + if verbose: + print(f" [Hybrid] Extracted {hybrid_k} key tokens: {key_text[:100]}...") + elif hybrid_k > 0: + print(f" WARNING: output_attentions returned empty — hybrid extraction skipped") + # Project to target model space projected, proj_metrics = conn_a.project_hidden_for_cross_model( last_hidden, avp_map, temperature=projection_temperature, @@ -150,6 +228,7 @@ def run_rosetta_pipeline( # --- Agent 2: Answerer on model B (inject projected embed, 
generate) --- inject_t0 = time.perf_counter() embed_input = rosetta_embeds.to(device).to(model_b.dtype) + embed_mask = torch.ones( (1, embed_input.shape[1]), dtype=torch.long, device=device, ) @@ -163,7 +242,15 @@ def run_rosetta_pipeline( past_kv_b = prime_out.past_key_values injection_ms = (time.perf_counter() - inject_t0) * 1000 - messages = build_latent_prompt(answerer.role, question) + # Hybrid: inject key text as context in Agent B's prompt (via input_ids) + hybrid_text_tokens = 0 + if key_text: + hybrid_text_tokens = len(tokenizer_b.encode(key_text, add_special_tokens=False)) + messages = build_latent_prompt( + answerer.role, question, key_context=key_text, + ) + else: + messages = build_latent_prompt(answerer.role, question) prompt_text = render_prompt(tokenizer_b, messages) input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device) @@ -230,8 +317,11 @@ def run_rosetta_pipeline( "hidden_state_norm": float(proj_metrics["hidden_state_norm"].mean()) if "hidden_state_norm" in proj_metrics else None, "nearest_cos_sim": float(proj_metrics["nearest_cos_sim"].mean()) if "nearest_cos_sim" in proj_metrics else None, "attention_entropy": attention_entropy, + "hybrid_k": hybrid_k, + "hybrid_text_tokens": hybrid_text_tokens if hybrid_k > 0 else 0, + "hybrid_key_text": key_text if hybrid_k > 0 else None, "agents": agent_traces, - "mode": "rosetta", + "mode": "hybrid" if hybrid_k > 0 else "rosetta", } @@ -252,12 +342,14 @@ def run_rosetta_benchmark( verbose: bool = False, projection_temperature: float = 1.0, num_transfer_states: int = 1, + hybrid_k: int = 0, ) -> List[Dict]: """Run rosetta-mode pipeline on HotpotQA samples.""" results = [] + mode_label = f"Hybrid K={hybrid_k}" if hybrid_k > 0 else "Rosetta" for i, sample in enumerate(dataset): if verbose: - print(f"\n[Rosetta] Sample {i + 1}/{len(dataset)}: {sample['question'][:80]}...") + print(f"\n[{mode_label}] Sample {i + 1}/{len(dataset)}: {sample['question'][:80]}...") result = 
run_rosetta_pipeline( conn_a, model_a, tokenizer_a, identity_a, @@ -272,6 +364,7 @@ def run_rosetta_benchmark( verbose=verbose, projection_temperature=projection_temperature, num_transfer_states=num_transfer_states, + hybrid_k=hybrid_k, ) results.append(result) @@ -285,7 +378,7 @@ def run_rosetta_benchmark( correct = sum(1 for r in results if r["exact_match"]) f1s = [r["f1"] for r in results] mean_f1 = sum(f1s) / len(f1s) - print(f" [Rosetta] {i + 1}/{len(dataset)} " + print(f" [{mode_label}] {i + 1}/{len(dataset)} " f"(EM={correct}/{i + 1}, F1={mean_f1:.2f}, {result['wall_time']:.1f}s)", flush=True) diff --git a/benchmarks/hotpotqa/run_hotpotqa.py b/benchmarks/hotpotqa/run_hotpotqa.py index 622e8de..b5626ba 100644 --- a/benchmarks/hotpotqa/run_hotpotqa.py +++ b/benchmarks/hotpotqa/run_hotpotqa.py @@ -272,6 +272,7 @@ def run_benchmark(config: dict) -> dict: output_dir = config.get("output_dir") projection_temperature = config.get("projection_temperature", 1.0) num_transfer_states = config.get("num_transfer_states", 1) + hybrid_k = config.get("hybrid_k", 0) model_b_name = config.get("model_b", "") @@ -297,7 +298,9 @@ def run_benchmark(config: dict) -> dict: print() dataset = load_dataset(max_samples, max_context_tokens) - model, tokenizer, connector, identity = load_model(model_name, device) + # Hybrid mode needs eager attention for output_attentions (SDPA silently ignores it) + attn_impl = "eager" if hybrid_k > 0 else None + model, tokenizer, connector, identity = load_model(model_name, device, attn_implementation=attn_impl) direct_results = None latent_results = None @@ -383,6 +386,7 @@ def run_benchmark(config: dict) -> dict: temperature=temperature, top_p=top_p, verbose=verbose, projection_temperature=projection_temperature, num_transfer_states=num_transfer_states, + hybrid_k=hybrid_k, ) del model_b, tokenizer_b, connector_b, identity_b diff --git a/benchmarks/shared/generation.py b/benchmarks/shared/generation.py index 541fc8e..2682c9b 100644 --- 
a/benchmarks/shared/generation.py +++ b/benchmarks/shared/generation.py @@ -47,6 +47,7 @@ def generate_text( max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.95, + logits_processor: Optional[Any] = None, ) -> Tuple[str, Optional[Any]]: """Generate text from input_ids, optionally with a pre-filled KV-cache. @@ -81,7 +82,7 @@ def generate_text( ) attention_mask = torch.cat([past_mask, attention_mask], dim=-1) - outputs = model.generate( + gen_kwargs = dict( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, @@ -93,6 +94,9 @@ def generate_text( past_key_values=past_key_values, cache_position=cache_position, ) + if logits_processor is not None: + gen_kwargs["logits_processor"] = logits_processor + outputs = model.generate(**gen_kwargs) generated_ids = outputs.sequences[0, prompt_len:] text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() diff --git a/benchmarks/shared/model_utils.py b/benchmarks/shared/model_utils.py index dfbc78e..7dededb 100644 --- a/benchmarks/shared/model_utils.py +++ b/benchmarks/shared/model_utils.py @@ -30,8 +30,13 @@ def auto_device(device: Optional[str]) -> str: return "cpu" -def load_model(model_name: str, device: str): - """Load model and tokenizer, return (model, tokenizer, connector, identity).""" +def load_model(model_name: str, device: str, attn_implementation: Optional[str] = None): + """Load model and tokenizer, return (model, tokenizer, connector, identity). + + Args: + attn_implementation: Override attention implementation (e.g. "eager" for + output_attentions support — SDPA silently ignores it). 
+ """ from transformers import AutoModelForCausalLM, AutoTokenizer from avp.connectors.huggingface import HuggingFaceConnector @@ -45,7 +50,11 @@ def load_model(model_name: str, device: str): dtype = torch.float16 else: dtype = torch.float32 - model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype) + kwargs = {} + if attn_implementation is not None: + kwargs["attn_implementation"] = attn_implementation + print(f" Using attention implementation: {attn_implementation}") + model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype, **kwargs) tokenizer = AutoTokenizer.from_pretrained(model_name) model.to(device) model.eval() diff --git a/src/avp/__init__.py b/src/avp/__init__.py index 54dc0f8..06452bf 100644 --- a/src/avp/__init__.py +++ b/src/avp/__init__.py @@ -84,7 +84,9 @@ "validate_projection", "TransferQualityConfig", "TransferQualityResult", + "TaskClassification", "assess_transfer", + "classify_task", } # Transport classes are lazy-loaded because httpx is an optional dependency. @@ -178,7 +180,9 @@ def __getattr__(name: str): "validate_projection", "TransferQualityConfig", "TransferQualityResult", + "TaskClassification", "assess_transfer", + "classify_task", # Errors "AVPError", "IncompatibleModelsError", diff --git a/src/avp/connectors/huggingface.py b/src/avp/connectors/huggingface.py index eb56470..d0090f8 100644 --- a/src/avp/connectors/huggingface.py +++ b/src/avp/connectors/huggingface.py @@ -541,10 +541,13 @@ def generate( context: Optional[AVPContext] = None, source: Optional["HuggingFaceConnector"] = None, cross_model: bool = False, + cross_model_method: str = "rosetta", max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95, do_sample: bool = True, + logit_bias_alpha: float = 0.5, + logit_bias_confidence_threshold: float = 0.8, _diagnostics: Optional[Any] = None, ) -> str: """Generate text, optionally conditioned on latent context from think(). @@ -558,10 +561,19 @@ def generate( Requires ``cross_model=True``. 
cross_model: Must be True to enable cross-model projection. Cross-model (Rosetta Stone) is experimental. Default False. + cross_model_method: Cross-model method to use. Options: + - ``"rosetta"`` (default): Project hidden state to target embedding + space and prime KV-cache with a single virtual token. + - ``"logit_guided"``: Map source vocabulary distribution to target + vocabulary as additive logit bias during generation. max_new_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling threshold. do_sample: Whether to use sampling (True) or greedy decoding (False). + logit_bias_alpha: Scaling factor for logit-guided bias (0.0-2.0). + Only used when cross_model_method="logit_guided". Default 0.5. + logit_bias_confidence_threshold: When target model's max softmax + probability exceeds this, skip bias for that step. Default 0.8. _diagnostics: Internal. Pre-created TransferDiagnostics to populate. Returns: @@ -573,6 +585,9 @@ def generate( """ import torch + logit_bias_processor = None + mid_layer_hook_ctx = None + # Cross-model: auto-project when source connector is provided if ( source is not None @@ -590,6 +605,27 @@ def generate( ) context = None source = None + elif cross_model_method == "logit_guided": + logit_bias_processor = self._compute_logit_bias_processor( + source, context, + alpha=logit_bias_alpha, + confidence_threshold=logit_bias_confidence_threshold, + _diagnostics=_diagnostics, + ) + # Don't use context for KV-cache priming in logit-guided mode + context = None + elif cross_model_method == "mid_layer": + mid_layer_hook_ctx = self._prepare_mid_layer_injection( + source, context, _diagnostics=_diagnostics, + ) + # Don't use KV-cache priming — injection happens via hook + context = None + elif cross_model_method == "trained": + mid_layer_hook_ctx = self._prepare_trained_injection( + source, context, _diagnostics=_diagnostics, + ) + # Don't use KV-cache priming — injection happens via per-layer hooks + context = 
None else: context = self._apply_rosetta_projection( source, context, _diagnostics=_diagnostics, @@ -648,7 +684,16 @@ def generate( if past_kv is not None: gen_kwargs["past_key_values"] = past_kv gen_kwargs["cache_position"] = cache_position - outputs = self.model.generate(**gen_kwargs) + if logit_bias_processor is not None: + from transformers import LogitsProcessorList + gen_kwargs["logits_processor"] = LogitsProcessorList( + [logit_bias_processor] + ) + if mid_layer_hook_ctx is not None: + with mid_layer_hook_ctx: + outputs = self.model.generate(**gen_kwargs) + else: + outputs = self.model.generate(**gen_kwargs) generated_ids = outputs.sequences[0, prompt_len:] text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip() @@ -673,6 +718,246 @@ def generate( return text + # --- Cross-model logit-guided decoding --- + + def _compute_logit_bias_processor( + self, + source: "HuggingFaceConnector", + context: AVPContext, + alpha: float = 0.5, + confidence_threshold: float = 0.8, + _diagnostics: Optional[Any] = None, + ) -> Any: + """Compute logit bias processor for cross-model guided decoding. + + Uses the source model's last hidden state (from think()) to compute + a vocabulary distribution, maps it through vocab overlap to the target + vocabulary, and returns a LogitsProcessor that applies it as additive + bias during generation. + + Args: + source: Source model connector. + context: AVPContext from source's think() — must have last_hidden_state. + alpha: Bias scaling factor. + confidence_threshold: Skip bias when target is this confident. + _diagnostics: Internal diagnostics object. + + Returns: + CrossModelLogitBias processor instance. + """ + if context.last_hidden_state is None: + raise ValueError( + "Logit-guided decoding requires context with last_hidden_state. " + "Use think() to produce the context." 
+ ) + + avp_map = self._get_or_calibrate_map(source) + + # Get source lm_head + source_lm_head = source.model.get_output_embeddings() + if source_lm_head is None: + source_lm_head = getattr(source.model, "lm_head", None) + if source_lm_head is None or not hasattr(source_lm_head, "weight"): + raise RealignmentError( + "Cannot get source output embeddings (lm_head) for " + "logit-guided decoding." + ) + + from ..rosetta.logit_guided import ( + CrossModelLogitBias, + compute_cross_model_logit_bias, + ) + + target_vocab_size = self.model.config.vocab_size + bias = compute_cross_model_logit_bias( + source_hidden_state=context.last_hidden_state, + source_lm_head_weight=source_lm_head.weight, + avp_map=avp_map, + target_vocab_size=target_vocab_size, + ) + + if _diagnostics is not None: + method = getattr(avp_map, "method", None) + method_name = method.name if hasattr(method, "name") else str(method or "") + _diagnostics.transfer_mode = "logit_guided" + _diagnostics.projection_method = method_name + + logger.info( + "Logit-guided decoding: alpha=%.2f, confidence_threshold=%.2f, " + "bias nonzero=%d/%d", + alpha, confidence_threshold, + int((bias != 0).sum()), target_vocab_size, + ) + + return CrossModelLogitBias( + bias=bias, + alpha=alpha, + confidence_threshold=confidence_threshold, + ) + + # --- Cross-model mid-layer injection --- + + def _prepare_mid_layer_injection( + self, + source: "HuggingFaceConnector", + context: AVPContext, + _diagnostics: Optional[Any] = None, + ) -> Any: + """Prepare mid-layer injection hook for cross-model generation. + + Projects the source hidden state to target space (same as rosetta) + but installs a forward hook to inject at an intermediate layer + (~75% depth) instead of priming KV-cache at layer 0. + + Args: + source: Source model connector. + context: AVPContext from source's think(). + _diagnostics: Internal diagnostics object. + + Returns: + Context manager (mid_layer_injection_hook) ready to wrap generate(). 
+ """ + if context.last_hidden_state is None: + raise ValueError( + "Mid-layer injection requires context with last_hidden_state. " + "Use think() to produce the context." + ) + + avp_map = self._get_or_calibrate_map(source) + + # Project source hidden state to target space + use_metrics = _diagnostics is not None + result = source.project_hidden_for_cross_model( + context.last_hidden_state, avp_map, + return_metrics=use_metrics, + ) + if use_metrics and isinstance(result, tuple): + projected, proj_metrics = result + if _diagnostics is not None and proj_metrics: + _diagnostics.hidden_state_norm = ( + float(proj_metrics["hidden_state_norm"].mean()) + if "hidden_state_norm" in proj_metrics else None + ) + _diagnostics.nearest_cos_sim = ( + float(proj_metrics["nearest_cos_sim"].mean()) + if "nearest_cos_sim" in proj_metrics else None + ) + else: + projected = result + + # Ensure [1, D] shape + if projected.dim() == 1: + projected = projected.unsqueeze(0) + if projected.dim() == 3: + projected = projected.squeeze(0) # [1, N, D] -> [N, D], take last + projected = projected[-1:, :] + + from ..rosetta.mid_layer import ( + compute_injection_layer, + mid_layer_injection_hook, + ) + + # Compute injection point + target_num_layers = self._identity.num_layers + injection_layer = compute_injection_layer(target_num_layers) + + method = getattr(avp_map, "method", None) + method_name = method.name if hasattr(method, "name") else str(method or "") + + if _diagnostics is not None: + _diagnostics.transfer_mode = "mid_layer" + _diagnostics.projection_method = method_name + + logger.info( + "Mid-layer injection: target layer %d/%d (%.0f%% depth), " + "projected shape %s", + injection_layer, target_num_layers, + 100 * injection_layer / target_num_layers, + tuple(projected.shape), + ) + + # Return the context manager — caller wraps model.generate() with it + return mid_layer_injection_hook( + model=self.model, + injection_layer=injection_layer, + projected_hidden=projected, + ) + + # --- 
Cross-model trained projection --- + + def _prepare_trained_injection( + self, + source: "HuggingFaceConnector", + context: AVPContext, + _diagnostics: Optional[Any] = None, + ) -> Any: + """Prepare per-layer trained projection hooks for cross-model generation. + + Uses a trained AVPMap with per-layer linear projections + learned gates. + Each layer's forward hook adds the projected source hidden state + (scaled by the learned gate) to the target layer's output. + + Args: + source: Source model connector. + context: AVPContext from source's think(). + _diagnostics: Internal diagnostics object. + + Returns: + Context manager (trained_multi_layer_hook) ready to wrap generate(). + """ + import torch + + if context.last_hidden_state is None: + raise ValueError( + "Trained projection requires context with last_hidden_state. " + "Use think() to produce the context." + ) + + avp_map = self._get_or_calibrate_map(source) + + if avp_map.layer_weights is None: + raise ValueError( + "AVPMap does not contain trained per-layer projections. " + "Use train_projector() to train a projection first, or " + "use cross_model_method='rosetta' for zero-training projection." 
+ ) + + # Source hidden state [1, D_src] or [D_src] + src_hidden = context.last_hidden_state.float() + if src_hidden.dim() == 1: + src_hidden = src_hidden.unsqueeze(0) # [1, D_src] + + # Pre-compute per-layer projections + layer_projections = [] + active_count = 0 + for i, (w, b, gate) in enumerate( + zip(avp_map.layer_weights, avp_map.layer_biases, avp_map.layer_gates) + ): + if gate < 0.01: + layer_projections.append(None) + continue + # w: [D_tgt, D_src], b: [D_tgt] + projected = torch.nn.functional.linear( + src_hidden, w.to(src_hidden.device), b.to(src_hidden.device) + ) # [1, D_tgt] + layer_projections.append((projected, gate)) + active_count += 1 + + if _diagnostics is not None: + _diagnostics.transfer_mode = "trained" + _diagnostics.projection_method = "TRAINED" + + logger.info( + "Trained injection: %d/%d active layers (gate > 0.01)", + active_count, len(avp_map.layer_gates), + ) + + from ..rosetta.trained_hooks import trained_multi_layer_hook + return trained_multi_layer_hook( + model=self.model, + layer_projections=layer_projections, + ) + # --- Cross-model rosetta projection --- def _apply_rosetta_projection( diff --git a/src/avp/easy.py b/src/avp/easy.py index 1ec4a2f..b4c4a54 100644 --- a/src/avp/easy.py +++ b/src/avp/easy.py @@ -252,6 +252,7 @@ def generate( model: str, source_model: Optional[str] = None, cross_model: bool = False, + cross_model_method: str = "rosetta", steps: int = 20, context: Optional["AVPContext"] = None, store: Optional[Any] = None, @@ -259,6 +260,7 @@ def generate( prior_key: Optional[str] = None, max_new_tokens: int = 512, temperature: float = 0.7, + logit_bias_alpha: float = 0.5, collect_metrics: bool = False, debug_config: Optional["DebugConfig"] = None, ) -> Union[str, Tuple[str, "GenerateMetrics"]]: @@ -372,8 +374,10 @@ def generate( context=source_context, source=source_connector, cross_model=True, + cross_model_method=cross_model_method, max_new_tokens=max_new_tokens, temperature=temperature, + 
logit_bias_alpha=logit_bias_alpha, _diagnostics=diagnostics, ) generate_duration = _time.perf_counter() - t_gen diff --git a/src/avp/rosetta/__init__.py b/src/avp/rosetta/__init__.py index 6cb3c9d..417ea14 100644 --- a/src/avp/rosetta/__init__.py +++ b/src/avp/rosetta/__init__.py @@ -6,7 +6,7 @@ from .calibrate import AVPMap, DEFAULT_ANCHORS, calibrate from .project import apply_cross_model_projection, vocab_overlap_projection, vocabulary_mediated_projection -from .quality import TransferQualityConfig, TransferQualityResult, assess_transfer +from .quality import TaskClassification, TransferQualityConfig, TransferQualityResult, assess_transfer, classify_task from .registry import find_map, load_map, map_id, save_map from .validate import ValidationConfig, ValidationResult, validate_projection @@ -22,7 +22,9 @@ # Quality gate "TransferQualityConfig", "TransferQualityResult", + "TaskClassification", "assess_transfer", + "classify_task", # Registry "save_map", "load_map", diff --git a/src/avp/rosetta/calibrate.py b/src/avp/rosetta/calibrate.py index 9710489..160ed93 100644 --- a/src/avp/rosetta/calibrate.py +++ b/src/avp/rosetta/calibrate.py @@ -42,6 +42,10 @@ class AVPMap: tgt_indices: Optional[Any] = None # LongTensor [N_shared] — target token IDs overlap_count: int = 0 overlap_ratio: float = 0.0 + # Trained translator fields (per-layer projections) + layer_weights: Optional[List[Any]] = None # List of Tensor [D_src, D_tgt] per layer + layer_biases: Optional[List[Any]] = None # List of Tensor [D_tgt] per layer + layer_gates: Optional[List[float]] = None # List of float gate values per layer def __post_init__(self) -> None: if isinstance(self.method, str): diff --git a/src/avp/rosetta/logit_guided.py b/src/avp/rosetta/logit_guided.py new file mode 100644 index 0000000..847aa5b --- /dev/null +++ b/src/avp/rosetta/logit_guided.py @@ -0,0 +1,134 @@ +"""Logit-guided decoding for cross-model communication. 
+ +Instead of compressing source model information into a single virtual token +(standard rosetta), distributes the signal across the target model's entire +autoregressive generation as additive logit biases. + +The source model's vocabulary distribution (from think()) is mapped through +vocabulary overlap to the target model's vocabulary and applied as a constant +bias during generation. +""" + +from typing import Any, Optional + +from .._torch_compat import require_torch as _require_torch + + +class CrossModelLogitBias: + """HuggingFace LogitsProcessor that applies cross-model logit bias. + + During each generation step, adds a scaled bias vector to the target model's + logits. Implements confidence gating: when the target model is already highly + confident (max probability > threshold), the bias is suppressed to avoid + pushing the model away from correct predictions. + + Compatible with transformers LogitsProcessor protocol (__call__ signature). + """ + + def __init__( + self, + bias: Any, + alpha: float = 0.5, + confidence_threshold: float = 0.8, + ): + """Initialize logit bias processor. + + Args: + bias: Tensor [target_vocab_size] — additive bias for target logits. + Should be zero-mean over mapped tokens. + alpha: Scaling factor for the bias. 0.5 = conservative (recommended + for cross-vocab mapping). Higher = stronger source influence. + confidence_threshold: When target's max softmax probability exceeds + this value, the bias is suppressed for that step. Prevents + "obvious blindness" (biasing away from correct predictions). + """ + self.bias = bias + self.alpha = alpha + self.confidence_threshold = confidence_threshold + + def __call__(self, input_ids: Any, scores: Any) -> Any: + """Apply cross-model logit bias with confidence gating. + + Args: + input_ids: [batch, seq_len] — generated token IDs so far. + scores: [batch, vocab_size] — target model logits for current step. + + Returns: + Modified logits [batch, vocab_size]. 
+ """ + torch = _require_torch() + + bias = self.bias.to(device=scores.device, dtype=scores.dtype) + + # Confidence gating: skip bias when target is already confident + with torch.no_grad(): + probs = torch.softmax(scores, dim=-1) + max_prob = probs.max(dim=-1).values # [batch] + + # Per-batch-element mask: 1.0 if uncertain, 0.0 if confident + mask = (max_prob < self.confidence_threshold).unsqueeze(-1).float() + + return scores + self.alpha * bias * mask + + +def compute_cross_model_logit_bias( + source_hidden_state: Any, + source_lm_head_weight: Any, + avp_map: Any, + target_vocab_size: int, + temperature: float = 1.0, +) -> Any: + """Compute logit bias vector for cross-model guided decoding. + + Takes the source model's last hidden state (from think()), computes its + vocabulary distribution, and maps it through vocab overlap to the target + model's vocabulary as an additive bias. + + Args: + source_hidden_state: Tensor [1, D_src] or [D_src] — last hidden state + from source model's think(). + source_lm_head_weight: Source model's lm_head weight [vocab_size_src, D_src]. + avp_map: AVPMap with src_indices, tgt_indices, and method. + target_vocab_size: Target model's vocabulary size. + temperature: Softmax temperature for computing source distribution. + Lower = sharper bias. Default 1.0. + + Returns: + Tensor [target_vocab_size] — zero-mean additive bias for target logits. 
+ """ + torch = _require_torch() + from ..types import ProjectionMethod + + h = source_hidden_state.detach().to(torch.float32) + if h.dim() == 2: + h = h.squeeze(0) # [D_src] + + w_src = source_lm_head_weight.detach().to(device=h.device, dtype=torch.float32) + + # Compute source log-probabilities + source_logits = torch.matmul(h, w_src.T) # [vocab_size_src] + source_log_probs = torch.log_softmax(source_logits / temperature, dim=-1) + + # Initialize target bias to zero (unmapped tokens get no bias) + target_bias = torch.zeros(target_vocab_size, device=h.device, dtype=torch.float32) + + if avp_map.method == ProjectionMethod.VOCAB_OVERLAP: + src_idx = avp_map.src_indices.to(h.device) + tgt_idx = avp_map.tgt_indices.to(h.device) + target_bias[tgt_idx] = source_log_probs[src_idx] + elif avp_map.method == ProjectionMethod.VOCAB_MEDIATED: + # Same tokenizer — direct 1:1 mapping + shared_vocab = min(source_log_probs.shape[0], target_vocab_size) + target_bias[:shared_vocab] = source_log_probs[:shared_vocab] + else: + # Ridge/Procrustes — no token-level mapping available + # Fall back to no bias (caller should use rosetta instead) + return target_bias + + # Zero-mean the mapped entries so the bias doesn't shift the distribution's + # center of mass (only nudges relative token preferences) + nonzero_mask = target_bias != 0.0 + if nonzero_mask.any(): + target_bias[nonzero_mask] -= target_bias[nonzero_mask].mean() + + return target_bias diff --git a/src/avp/rosetta/mid_layer.py b/src/avp/rosetta/mid_layer.py new file mode 100644 index 0000000..876ce19 --- /dev/null +++ b/src/avp/rosetta/mid_layer.py @@ -0,0 +1,210 @@ +"""Mid-layer injection for cross-model latent transfer. + +Instead of injecting projected hidden states at layer 0 (via inputs_embeds), +injects at an intermediate layer (~75% depth) using a forward hook. This +bypasses the early embedding/position-encoding layers and operates directly +in the semantic representation space. 
# Default extraction/injection depth ratio (validated at 0.75 = ~75% depth)
DEFAULT_DEPTH_RATIO = 0.75


def _depth_ratio_to_layer_index(num_layers: int, depth_ratio: float) -> int:
    """Map a fractional depth onto a valid 0-indexed layer position."""
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")
    # Clamp on both sides: a ratio >= 1.0 lands on the last layer, and a
    # negative ratio lands on layer 0 instead of producing a negative index
    # that would silently wrap around when used to index a layer list.
    return max(0, min(int(num_layers * depth_ratio), num_layers - 1))


def compute_extraction_layer(num_layers: int, depth_ratio: float = DEFAULT_DEPTH_RATIO) -> int:
    """Compute the layer index to extract hidden states from.

    Args:
        num_layers: Total number of transformer layers in the model.
        depth_ratio: Fraction of depth to extract from (0.0=first, 1.0=last).

    Returns:
        Layer index (0-indexed), always within [0, num_layers - 1].

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    return _depth_ratio_to_layer_index(num_layers, depth_ratio)


def compute_injection_layer(num_layers: int, depth_ratio: float = DEFAULT_DEPTH_RATIO) -> int:
    """Compute the layer index to inject hidden states into.

    Uses proportional depth mapping: if the source was extracted from 75%
    depth, inject at 75% depth of the target model (even if it has a
    different number of layers).

    Args:
        num_layers: Total number of transformer layers in the target model.
        depth_ratio: Fraction of depth to inject at.

    Returns:
        Layer index (0-indexed), always within [0, num_layers - 1].

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    return _depth_ratio_to_layer_index(num_layers, depth_ratio)
def _get_decoder_layers(model: Any):
    """Locate the decoder layer container on a HuggingFace-style model.

    Two attribute paths are probed in order: ``model.model.layers`` and then
    ``model.transformer.h`` (GPT-2 style). The first path whose container and
    layer attribute are both present and not ``None`` wins.

    Raises:
        AttributeError: If neither path resolves.
    """
    # (container attribute, layer-list attribute), in lookup priority order
    candidate_paths = (
        ("model", "layers"),
        ("transformer", "h"),
    )
    for container_name, layers_name in candidate_paths:
        container = getattr(model, container_name, None)
        if container is None:
            continue
        decoder_layers = getattr(container, layers_name, None)
        if decoder_layers is not None:
            return decoder_layers

    raise AttributeError(
        f"Cannot find decoder layers in model {type(model).__name__}. "
        "Expected model.model.layers or model.transformer.h"
    )
def project_for_mid_layer(
    source_hidden: Any,
    avp_map: Any,
    source_model: Any,
    target_model: Any,
    target_num_layers: int,
    injection_depth_ratio: float = DEFAULT_DEPTH_RATIO,
) -> Tuple[Any, int]:
    """Project a source hidden state and pick the target layer to inject at.

    The projection itself is delegated unchanged to
    ``apply_cross_model_projection`` (the same vocab-mediated/overlap
    projection used for layer-0 rosetta); this function adds no rescaling
    of its own.
    NOTE(review): any norm matching to the target layer's activation scale
    would therefore have to happen inside ``apply_cross_model_projection``
    -- confirm against that helper.

    Args:
        source_hidden: Source hidden state [1, D_src] or [D_src].
        avp_map: AVPMap with projection data.
        source_model: Source HuggingFace model.
        target_model: Target HuggingFace model.
        target_num_layers: Number of layers in target model.
        injection_depth_ratio: Depth ratio for injection point.

    Returns:
        Tuple of (projected_hidden, injection_layer_index). A 1-D projection
        result is given a leading batch dimension, so it comes back as
        [1, D_tgt].
    """
    from .project import apply_cross_model_projection

    injection_layer = compute_injection_layer(target_num_layers, injection_depth_ratio)

    # Standard vocab-mediated/overlap projection
    projected = apply_cross_model_projection(
        source_hidden, avp_map, source_model, target_model,
    )

    # Ensure a leading batch dimension: [D] -> [1, D]
    if projected.dim() == 1:
        projected = projected.unsqueeze(0)

    return projected, injection_layer
# Compiled patterns for task classification (module-level for zero overhead)
# LaTeX commands, "<number> <operator> <number>" arithmetic, or $...$ spans
_MATH_MARKERS = re.compile(
    r'\\boxed|\\frac|\\sqrt|\\sum|\\int|\\cdot|'
    r'\d+\s*[+\-*/=]\s*\d+|'
    r'\$[^$]+\$'
)
# Source-code keywords, code fences and interpreter prompts
_CODE_MARKERS = re.compile(
    r'\bdef\s+\w+|'
    r'\bclass\s+\w+|'
    r'\bimport\s+\w+|'
    r'\bfunction\s+\w+|'
    r'\breturn\s+|'
    r'```|'
    r'>>>|'
    r'\bfor\s+\w+\s+in\s+'
)
# Reading-comprehension question openers at the start of the text or right
# after a newline / sentence-ending punctuation
_COMPREHENSION_STARTERS = re.compile(
    r'(?:^|[\n.?!]\s*)(?:who\s+(?:is|was|were|are|did)|'
    r'what\s+(?:is|was|were|are|did|does)|'
    r'when\s+(?:did|was|were|is)|'
    r'where\s+(?:did|was|were|is)|'
    r'why\s+(?:did|was|were|is)|'
    r'how\s+did|'
    r'according\s+to|based\s+on|which\s+of\s+the\s+following|'
    r'in\s+the\s+(?:passage|text|article|context))',
    re.IGNORECASE,
)
# Imperative math/coding openers at the start of the text or right after a
# newline / period
_STRUCTURED_STARTERS = re.compile(
    r'(?:^|[\n.]\s*)(?:solve|compute|calculate|find\s+the|evaluate|simplify|'
    r'how\s+much|how\s+many|'
    r'implement|write\s+(?:a\s+)?(?:function|program|code|class)|'
    r'debug|fix\s+(?:the|this)|refactor)',
    re.IGNORECASE,
)


@dataclass
class TaskClassification:
    """Outcome of classifying a prompt by task type.

    Attributes:
        task_type: 'structured' or 'comprehension'.
        score: Sum of the feature contributions (positive=structured,
            zero or negative=comprehension).
        features: Per-feature contribution, keyed by feature name.
    """

    task_type: str
    score: int
    features: dict = field(default_factory=dict)


def classify_task(prompt_text: str, prompt_tokens: int = 0) -> TaskClassification:
    """Label a prompt as 'structured' or 'comprehension'.

    Five integer signals are computed from the prompt and summed; a total
    strictly above zero means 'structured', anything else 'comprehension'.

    Signals (key in ``features`` -> contribution):
        token_count:   >500 tokens -2, >300 -1, strictly between 0 and 200 +1,
                       otherwise 0.
        digit_density: share of digit characters >0.05 +2, >0.02 +1, else 0.
        markers:       up to +2 for math-pattern hits plus up to +2 for
                       code-pattern hits (1 hit = 1, 2+ hits = 2), capped at 3.
        question_type: comprehension opener without a structured opener -2,
                       structured opener without a comprehension opener +1,
                       otherwise 0.
        paragraphs:    3+ blank-line separators ("\\n\\n") -2, exactly 2 -1,
                       otherwise 0.

    Args:
        prompt_text: The raw prompt text.
        prompt_tokens: Token count; the default 0 yields a neutral length signal.

    Returns:
        TaskClassification with task_type, score, and feature breakdown.
    """
    # Signal 1: prompt length in tokens
    if prompt_tokens > 500:
        length_signal = -2
    elif prompt_tokens > 300:
        length_signal = -1
    else:
        length_signal = 1 if 0 < prompt_tokens < 200 else 0

    # Signal 2: fraction of characters that are digits
    digit_count = sum(ch.isdigit() for ch in prompt_text)
    digit_share = digit_count / max(len(prompt_text), 1)
    if digit_share > 0.05:
        digit_signal = 2
    elif digit_share > 0.02:
        digit_signal = 1
    else:
        digit_signal = 0

    # Signal 3: math / code syntax; each family contributes at most 2 and
    # the combined contribution is capped at 3
    math_hits = sum(1 for _ in _MATH_MARKERS.finditer(prompt_text))
    code_hits = sum(1 for _ in _CODE_MARKERS.finditer(prompt_text))
    marker_signal = min(min(math_hits, 2) + min(code_hits, 2), 3)

    # Signal 4: kind of question/instruction opener; only an unambiguous
    # prompt (one family present, the other absent) moves the score
    asks_comprehension = _COMPREHENSION_STARTERS.search(prompt_text) is not None
    asks_structured = _STRUCTURED_STARTERS.search(prompt_text) is not None
    if asks_comprehension and not asks_structured:
        question_signal = -2
    elif asks_structured and not asks_comprehension:
        question_signal = 1
    else:
        question_signal = 0

    # Signal 5: number of blank-line separators in the prompt
    blank_line_breaks = prompt_text.count("\n\n")
    if blank_line_breaks >= 3:
        paragraph_signal = -2
    else:
        paragraph_signal = -1 if blank_line_breaks == 2 else 0

    features = {
        "token_count": length_signal,
        "digit_density": digit_signal,
        "markers": marker_signal,
        "question_type": question_signal,
        "paragraphs": paragraph_signal,
    }
    score = sum(features.values())
    return TaskClassification(
        task_type="structured" if score > 0 else "comprehension",
        score=score,
        features=features,
    )
Only used when config.check_effective_rank=True. @@ -148,19 +322,67 @@ def assess_transfer( config = TransferQualityConfig() effective_rank_ratio: Optional[float] = None + task_cls: Optional[TaskClassification] = None - # Primary gate: prompt token count - if prompt_tokens > config.max_prompt_tokens: - return TransferQualityResult( - recommend_latent=False, - prompt_tokens=prompt_tokens, - reason=( - f"prompt_tokens={prompt_tokens} exceeds " - f"max_prompt_tokens={config.max_prompt_tokens}; " - f"single embedding unlikely to capture sufficient information" - ), - effective_rank_ratio=None, - ) + # Task classification (if prompt text available) + if prompt_text is not None and config.use_task_classification: + task_cls = classify_task(prompt_text, prompt_tokens) + + # Combined gate: token count + task classification + if task_cls is not None: + # Strong comprehension signal overrides even short prompts + if task_cls.task_type == "comprehension" and task_cls.score <= -3: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"task classified as comprehension (score={task_cls.score}); " + f"single embedding unlikely to capture context" + ), + task_classification=task_cls, + ) + + # Strong structured signal allows slightly longer prompts + if task_cls.task_type == "structured" and task_cls.score >= 3: + # Allow up to 1.5x the normal token limit for strongly structured + extended_limit = int(config.max_prompt_tokens * 1.5) + if prompt_tokens > extended_limit: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"prompt_tokens={prompt_tokens} exceeds extended limit " + f"{extended_limit} (structured task, score={task_cls.score})" + ), + task_classification=task_cls, + ) + # Within extended limit -- recommend latent + else: + # Moderate signals: use standard token limit + if prompt_tokens > config.max_prompt_tokens: + return TransferQualityResult( + recommend_latent=False, + 
prompt_tokens=prompt_tokens, + reason=( + f"prompt_tokens={prompt_tokens} exceeds " + f"max_prompt_tokens={config.max_prompt_tokens}; " + f"task_type={task_cls.task_type} (score={task_cls.score})" + ), + task_classification=task_cls, + ) + else: + # No task classification -- fall back to token count only + if prompt_tokens > config.max_prompt_tokens: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"prompt_tokens={prompt_tokens} exceeds " + f"max_prompt_tokens={config.max_prompt_tokens}; " + f"single embedding unlikely to capture sufficient information" + ), + effective_rank_ratio=None, + ) # Secondary gate: effective rank (opt-in) if config.check_effective_rank and hidden_states is not None: @@ -175,6 +397,7 @@ def assess_transfer( f"information too spread for single-embedding transfer" ), effective_rank_ratio=effective_rank_ratio, + task_classification=task_cls, ) return TransferQualityResult( @@ -182,4 +405,5 @@ def assess_transfer( prompt_tokens=prompt_tokens, reason="transfer within quality thresholds", effective_rank_ratio=effective_rank_ratio, + task_classification=task_cls, ) diff --git a/src/avp/rosetta/registry.py b/src/avp/rosetta/registry.py index 8ef67e8..72b4fec 100644 --- a/src/avp/rosetta/registry.py +++ b/src/avp/rosetta/registry.py @@ -56,6 +56,11 @@ def save_map(avp_map: AVPMap, map_dir: Optional[Path] = None) -> Path: "overlap_count": avp_map.overlap_count, "overlap_ratio": avp_map.overlap_ratio, } + # Trained translator fields (per-layer projections) + if avp_map.layer_weights is not None: + data["layer_weights"] = [w.cpu() for w in avp_map.layer_weights] + data["layer_biases"] = [b.cpu() for b in avp_map.layer_biases] + data["layer_gates"] = avp_map.layer_gates torch.save(data, path) return path @@ -101,6 +106,9 @@ def load_map( tgt_indices=data.get("tgt_indices"), overlap_count=data.get("overlap_count", 0), overlap_ratio=data.get("overlap_ratio", 0.0), + 
def _require_torch():
    """Import and return the ``torch`` module.

    Raises:
        ImportError: With an install hint if torch is not installed; the
            original import failure is chained as the cause.
    """
    try:
        import torch
    except ImportError as err:
        raise ImportError("torch is required for training. pip install torch") from err
    return torch


@dataclass
class TrainConfig:
    """Configuration for training a cross-model projector.

    Attributes:
        num_samples: Number of text samples to train on.
        batch_size: Training batch size.
        learning_rate: Adam learning rate.
        num_epochs: Number of training epochs.
        gate_reg_weight: L1 regularization weight for gate sparsity.
        gate_init: Initial gate logit (sigmoid(gate_init) = initial gate value).
        max_seq_len: Maximum sequence length for training samples.
        extraction_layer_ratio: Depth ratio for source hidden state extraction.
        warmup_steps: Number of linear warmup steps.
        seed: Random seed for reproducibility.
        mse_aux_weight: Weight of the MSE loss between projected and
            reference hidden states (0 = NTP only).
        use_ntp_loss: Whether to add the next-token-prediction loss through
            the hooked target model. Off by default.
    """

    num_samples: int = 5000
    batch_size: int = 4
    learning_rate: float = 1e-4
    num_epochs: int = 2
    gate_reg_weight: float = 0.01
    gate_init: float = -3.0  # sigmoid(-3) ~ 0.05 — warm enough for gradient signal
    max_seq_len: int = 256
    extraction_layer_ratio: float = 0.75
    warmup_steps: int = 100
    seed: int = 42
    mse_aux_weight: float = 0.1  # Weight for MSE auxiliary loss (0 = NTP only)
    use_ntp_loss: bool = False  # MSE-only is sufficient (NTP adds no accuracy, doubles compute)


class LayerProjector:
    """Per-layer linear projections with learned sigmoid gates.

    Holds one ``nn.Linear(source_dim, target_dim)`` and one scalar gate logit
    per target layer. ``sigmoid(gate logit)`` is the injection strength for
    that layer; all gates start at ``sigmoid(config.gate_init)``.

    This is a plain Python class, not an ``nn.Module``; it exposes the small
    subset of the Module API (``parameters``, ``to``, ``train``, ``eval``,
    ``forward``) it needs.
    """

    def __init__(self, source_dim: int, target_dim: int, num_layers: int, config: Optional[TrainConfig] = None):
        torch = _require_torch()
        import torch.nn as nn

        if config is None:
            config = TrainConfig()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.num_layers = num_layers

        # Per-layer projection: source hidden -> target hidden at layer L
        self.layer_projections = nn.ModuleList([
            nn.Linear(source_dim, target_dim, bias=True)
            for _ in range(num_layers)
        ])

        # Per-layer gate logit: sigmoid(logit) controls injection strength
        self.layer_gates = nn.ParameterList([
            nn.Parameter(torch.tensor(config.gate_init))
            for _ in range(num_layers)
        ])

        # Start from small projection weights and zero bias
        for proj in self.layer_projections:
            nn.init.xavier_uniform_(proj.weight, gain=0.1)
            nn.init.zeros_(proj.bias)

    def parameters(self):
        """Yield all trainable parameters: projection weights/biases, then gates."""
        for proj in self.layer_projections:
            yield from proj.parameters()
        for gate in self.layer_gates:
            yield gate

    def to(self, device):
        """Move all parameters to ``device`` and return ``self``.

        Both containers are ``nn.Module`` subclasses, so ``Module.to`` moves
        their parameters in place and every entry stays the same registered
        ``nn.Parameter``. Reassigning ``gate.to(device)`` per entry would
        instead hand the ParameterList a plain tensor whenever the device
        actually changes.
        """
        self.layer_projections.to(device)
        self.layer_gates.to(device)
        return self

    def train(self):
        """Set to training mode."""
        for proj in self.layer_projections:
            proj.train()

    def eval(self):
        """Set to evaluation mode."""
        for proj in self.layer_projections:
            proj.eval()

    def forward(self, source_hidden, return_gate_tensors: bool = False):
        """Project source hidden state to each target layer.

        Args:
            source_hidden: [B, D_src] from source model.
            return_gate_tensors: If True, return gate as tensor (for training
                gradient flow). If False, return gate as float (for inference).

        Returns:
            List of (projected_hidden [B, D_tgt], gate_value) per layer.
            gate_value is a float (inference) or tensor (training).
        """
        torch = _require_torch()
        results = []
        for proj, gate_logit in zip(self.layer_projections, self.layer_gates):
            projected = proj(source_hidden)  # [B, D_tgt]
            gate_tensor = torch.sigmoid(gate_logit)  # scalar tensor in [0, 1]
            if return_gate_tensors:
                results.append((projected, gate_tensor))
            else:
                results.append((projected, gate_tensor.item()))
        return results

    def get_active_layers(self, threshold: float = 0.01) -> List[int]:
        """Return indices of layers whose gate value is strictly above ``threshold``."""
        torch = _require_torch()
        return [
            i
            for i, gate_logit in enumerate(self.layer_gates)
            if torch.sigmoid(gate_logit).item() > threshold
        ]

    def export_weights(self) -> Tuple[List[Any], List[Any], List[float]]:
        """Export weights for serialization into AVPMap.

        Returns:
            Tuple of (layer_weights, layer_biases, layer_gates), each a list
            of length num_layers. Weights and biases are detached CPU
            tensors; gates are post-sigmoid floats.
        """
        torch = _require_torch()
        weights = []
        biases = []
        gates = []
        for proj, gate_logit in zip(self.layer_projections, self.layer_gates):
            weights.append(proj.weight.detach().cpu())  # [D_tgt, D_src]
            biases.append(proj.bias.detach().cpu())  # [D_tgt]
            gates.append(torch.sigmoid(gate_logit).item())  # float
        return weights, biases, gates
def train_projector(
    source_model: Any,
    target_model: Any,
    source_tokenizer: Any,
    target_tokenizer: Any,
    device: str = "cuda",
    config: Optional[TrainConfig] = None,
    progress_callback: Optional[Any] = None,
) -> "AVPMap":
    """Train a per-layer cross-model projector.

    Both models are put in eval mode and frozen; only the LayerProjector
    trains. Per batch, the source model's last-position hidden state (taken
    at ``config.extraction_layer_ratio`` depth) is projected to every target
    layer, and the loss is the sum of:

    * optional next-token-prediction loss through the hooked target model
      (``config.use_ntp_loss``),
    * gate-weighted MSE against the unhooked target model's last-position
      hidden state at each layer (``config.mse_aux_weight > 0``),
    * a gate sparsity term (``config.gate_reg_weight``).

    ``config.warmup_steps`` is not used by this function.
    NOTE(review): hidden states are read at sequence position -1 of padded
    batches; with right-padding that position is a pad token for shorter
    samples -- confirm the tokenizers' padding side.

    Args:
        source_model: Source HuggingFace model (frozen).
        target_model: Target HuggingFace model (frozen).
        source_tokenizer: Source tokenizer.
        target_tokenizer: Target tokenizer (also passed to the data loader).
        device: Training device.
        config: Training configuration; defaults to ``TrainConfig()``.
        progress_callback: Optional callback(step, total_steps, loss).

    Returns:
        AVPMap with method=TRAINED and per-layer projection data.

    Raises:
        RuntimeError: If not a single optimizer step could be taken (no
            training texts, or every batch failed tokenization).
    """
    torch = _require_torch()
    import torch.nn.functional as F
    from ..handshake import compute_model_hash, extract_model_identity

    if config is None:
        config = TrainConfig()

    # Set seed
    torch.manual_seed(config.seed)

    # Get model info
    src_identity = extract_model_identity(source_model, source_tokenizer)
    tgt_identity = extract_model_identity(target_model, target_tokenizer)
    src_hash = compute_model_hash(source_model.config.to_dict())
    tgt_hash = compute_model_hash(target_model.config.to_dict())

    source_dim = src_identity.hidden_dim
    target_dim = tgt_identity.hidden_dim
    target_num_layers = tgt_identity.num_layers

    logger.info(
        "Training projector: %s (%dd) -> %s (%dd, %d layers)",
        src_identity.model_id, source_dim,
        tgt_identity.model_id, target_dim, target_num_layers,
    )

    # Create projector
    projector = LayerProjector(source_dim, target_dim, target_num_layers, config)
    projector.to(device)
    projector.train()

    # Freeze both models
    source_model.eval()
    target_model.eval()
    for p in source_model.parameters():
        p.requires_grad_(False)
    for p in target_model.parameters():
        p.requires_grad_(False)

    # Load training data
    texts = _load_training_data(
        target_tokenizer, config.num_samples, config.max_seq_len, config.seed
    )

    # Optimizer
    optimizer = torch.optim.AdamW(
        projector.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01,
    )

    # Matches "cuda" as well as indexed devices such as "cuda:0"
    on_cuda = str(device).startswith("cuda")

    # Training loop
    num_batches = math.ceil(len(texts) / config.batch_size)
    total_steps = num_batches * config.num_epochs
    step = 0
    best_loss = float("inf")

    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        epoch_steps = 0

        for batch_start in range(0, len(texts), config.batch_size):
            batch_texts = texts[batch_start:batch_start + config.batch_size]

            # Both models process the same text, each with its own tokenizer.
            # A failing batch is skipped (best effort), but no longer silently.
            try:
                tgt_encoded = _tokenize_batch(batch_texts, target_tokenizer, config.max_seq_len, device)
                src_encoded = _tokenize_batch(batch_texts, source_tokenizer, config.max_seq_len, device)
            except Exception as exc:
                logger.warning(
                    "Skipping batch at offset %d: tokenization failed (%s)",
                    batch_start, exc,
                )
                continue

            # Source forward pass -> extract hidden state
            with torch.no_grad():
                src_out = source_model(
                    **src_encoded,
                    output_hidden_states=True,
                    return_dict=True,
                )
                # Extract from extraction layer ratio
                src_hidden_states = src_out.hidden_states
                extract_idx = int(len(src_hidden_states) * config.extraction_layer_ratio)
                extract_idx = min(extract_idx, len(src_hidden_states) - 1)
                src_hidden = src_hidden_states[extract_idx][:, -1, :]  # [B, D_src]

            # Project source hidden state through all layers
            projections = projector.forward(src_hidden.float(), return_gate_tensors=True)

            loss = torch.tensor(0.0, device=device, requires_grad=True)

            # Reference forward pass (unhooked, no grad) for MSE auxiliary.
            # Only run when the MSE term is enabled; otherwise its output
            # would never be read.
            ref_hidden_states = None
            if config.mse_aux_weight > 0:
                with torch.no_grad():
                    ref_out = target_model(
                        **tgt_encoded,
                        output_hidden_states=True,
                        return_dict=True,
                    )
                    ref_hidden_states = ref_out.hidden_states

            # Primary loss: NTP through target model with injected projections
            if config.use_ntp_loss:
                from .trained_hooks import trained_multi_layer_hook

                # Per-layer hook payloads; layers with a near-zero gate are
                # passed as None
                layer_proj_for_hooks = []
                for proj_h, gate in projections:
                    if gate.item() < 0.001:
                        layer_proj_for_hooks.append(None)
                    else:
                        layer_proj_for_hooks.append((proj_h, gate))

                tgt_input_ids = tgt_encoded["input_ids"]
                labels = tgt_input_ids[:, 1:].contiguous()  # shift right

                # `pad_token_id or -100` would treat a legitimate pad id of 0
                # as "no pad token" and train on padding positions.
                pad_id = target_tokenizer.pad_token_id
                ignore_index = -100 if pad_id is None else pad_id

                with trained_multi_layer_hook(target_model, layer_proj_for_hooks):
                    tgt_out = target_model(
                        **tgt_encoded,
                        output_hidden_states=False,
                        return_dict=True,
                    )
                logits = tgt_out.logits[:, :-1, :].contiguous()  # [B, seq-1, vocab]
                ntp_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    ignore_index=ignore_index,
                )
                loss = loss + ntp_loss

            # Auxiliary loss: MSE between projected and unhooked reference hidden states
            if config.mse_aux_weight > 0 and ref_hidden_states is not None:
                mse_loss = torch.tensor(0.0, device=device, requires_grad=True)
                for layer_idx, (proj_h, gate) in enumerate(projections):
                    gate_val = gate.item()
                    if gate_val < 0.001:
                        continue
                    # hidden_states[0] is read as the embedding output, so
                    # layer L's output is at index L + 1 (clamped to the end)
                    ref_idx = min(layer_idx + 1, len(ref_hidden_states) - 1)
                    ref_h = ref_hidden_states[ref_idx][:, -1, :].float().detach()
                    layer_loss = F.mse_loss(proj_h, ref_h)
                    mse_loss = mse_loss + gate * layer_loss
                loss = loss + config.mse_aux_weight * mse_loss

            # Gate sparsity regularization (encourage most gates to be ~0)
            gate_sum = sum(
                torch.sigmoid(g) for g in projector.layer_gates
            ) / target_num_layers
            loss = loss + config.gate_reg_weight * gate_sum

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(projector.parameters(), 1.0)

            optimizer.step()

            step += 1
            loss_val = loss.item()
            epoch_loss += loss_val
            epoch_steps += 1

            if step % 50 == 0:
                active = projector.get_active_layers()
                logger.info(
                    "Step %d/%d, loss=%.4f, active_layers=%d/%d",
                    step, total_steps, loss_val, len(active), target_num_layers,
                )

            if progress_callback is not None:
                progress_callback(step, total_steps, loss_val)

            # Free intermediate tensors
            del src_out, projections, loss
            ref_hidden_states = None
            if on_cuda:
                torch.cuda.empty_cache()

        if epoch_steps == 0:
            # An epoch without a single step has no loss to average; letting
            # it through would record avg_loss=0.0 as the best loss.
            logger.warning(
                "Epoch %d/%d produced no training steps", epoch + 1, config.num_epochs
            )
            continue

        avg_loss = epoch_loss / epoch_steps
        logger.info("Epoch %d/%d complete, avg_loss=%.4f", epoch + 1, config.num_epochs, avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss

    if step == 0:
        # Nothing was trained: exporting would hand back a randomly
        # initialised projector.
        raise RuntimeError(
            "train_projector: no training batch could be processed "
            "(empty dataset or every batch failed tokenization)"
        )

    # Export to AVPMap
    projector.eval()
    layer_weights, layer_biases, layer_gates = projector.export_weights()

    active = projector.get_active_layers()
    logger.info(
        "Training complete. Active layers: %d/%d (gates > 0.01): %s",
        len(active), target_num_layers, active,
    )

    from .calibrate import AVPMap
    from ..types import ProjectionMethod

    # Compute target norm from target model embeddings
    tgt_embed = target_model.get_input_embeddings()
    target_norm = tgt_embed.weight.float().norm(dim=-1).mean().detach().cpu()

    avp_map = AVPMap(
        source_model_id=src_identity.model_id,
        source_hash=src_hash,
        source_dim=source_dim,
        target_model_id=tgt_identity.model_id,
        target_hash=tgt_hash,
        target_dim=target_dim,
        w_map=torch.zeros(1),  # Placeholder (projections stored in layer_weights)
        bias=None,
        target_norm=target_norm,
        method=ProjectionMethod.TRAINED,
        anchor_count=config.num_samples,
        validation_score=1.0 - best_loss,  # Higher is better
        layer_weights=layer_weights,
        layer_biases=layer_biases,
        layer_gates=layer_gates,
    )

    return avp_map
+""" + +import logging +from contextlib import contextmanager +from typing import Any, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _get_decoder_layers(model: Any): + """Get the list of decoder layers from a HuggingFace model.""" + inner = getattr(model, "model", None) + if inner is not None: + layers = getattr(inner, "layers", None) + if layers is not None: + return layers + + transformer = getattr(model, "transformer", None) + if transformer is not None: + h = getattr(transformer, "h", None) + if h is not None: + return h + + raise AttributeError( + f"Cannot find decoder layers in model {type(model).__name__}. " + "Expected model.model.layers or model.transformer.h" + ) + + +@contextmanager +def trained_multi_layer_hook( + model: Any, + layer_projections: List[Optional[Tuple[Any, float]]], +): + """Context manager that installs per-layer forward hooks for trained projection. + + Each hook adds the pre-computed projected hidden state (scaled by gate) + to the target layer's last-token hidden state. Hooks fire only on the + first forward pass (prefill), not during autoregressive generation. + + Args: + model: HuggingFace model to hook into. + layer_projections: List of length num_layers. Each entry is either: + - None (gate < threshold, skip this layer) + - (projected_hidden [1, D_tgt], gate_value float) + + Yields: + None. Hooks are active during the context. 
+ """ + import torch + + layers = _get_decoder_layers(model) + handles = [] + fired_flags = [] + + for layer_idx, proj_data in enumerate(layer_projections): + if proj_data is None: + continue + if layer_idx >= len(layers): + break + + projected, gate = proj_data + target_layer = layers[layer_idx] + fired = [False] + fired_flags.append(fired) + + def make_hook(proj_h, g, f): + def hook_fn(module, input, output): + if f[0]: + return output + + f[0] = True + + if isinstance(output, tuple): + hidden = output[0] # [B, seq_len, D] + else: + hidden = output + + # Add projected source hidden state (scaled by gate) to last token + injection = proj_h.to(device=hidden.device, dtype=hidden.dtype) + if injection.dim() == 2: + injection = injection.squeeze(0) # [1, D] -> [D] + + modified = hidden.clone() + modified[:, -1, :] = modified[:, -1, :] + g * injection + + if isinstance(output, tuple): + return (modified,) + output[1:] + return modified + + return hook_fn + + hook = make_hook(projected, gate, fired) + handle = target_layer.register_forward_hook(hook) + handles.append(handle) + + try: + yield + finally: + for handle in handles: + handle.remove() diff --git a/src/avp/types.py b/src/avp/types.py index c156253..3ac747d 100644 --- a/src/avp/types.py +++ b/src/avp/types.py @@ -60,6 +60,7 @@ class ProjectionMethod(enum.Enum): PROCRUSTES = "procrustes" VOCAB_MEDIATED = "vocab_mediated" VOCAB_OVERLAP = "vocab_overlap" + TRAINED = "trained" class DataType(enum.IntEnum): diff --git a/tests/test_logit_guided.py b/tests/test_logit_guided.py new file mode 100644 index 0000000..834a327 --- /dev/null +++ b/tests/test_logit_guided.py @@ -0,0 +1,269 @@ +"""Tests for logit-guided cross-model decoding.""" + +import pytest +import torch +from conftest import requires_torch, requires_transformers + + +# --------------------------------------------------------------------------- +# Unit tests for compute_cross_model_logit_bias +# 
import pytest
import torch
from conftest import requires_torch, requires_transformers


# ---------------------------------------------------------------------------
# Unit tests for compute_cross_model_logit_bias
# ---------------------------------------------------------------------------

@requires_torch
class TestComputeLogitBias:
    """Test bias computation from source hidden states."""

    def _make_avp_map(self, method, src_indices=None, tgt_indices=None,
                      target_norm=None):
        """Create a minimal AVPMap for testing.

        ``target_norm`` defaults to ``torch.tensor(1.0)`` only when it is
        ``None``. An explicit ``is None`` check is used because tensor
        truthiness would replace a zero norm and raises for tensors with
        more than one element.
        """
        from avp.rosetta.calibrate import AVPMap

        if target_norm is None:
            target_norm = torch.tensor(1.0)

        return AVPMap(
            source_model_id="source",
            source_hash="src_hash",
            source_dim=32,
            target_model_id="target",
            target_hash="tgt_hash",
            target_dim=32,
            w_map=torch.randn(32, 32),
            bias=None,
            target_norm=target_norm,
            method=method,
            anchor_count=0,
            validation_score=0.0,
            src_indices=src_indices,
            tgt_indices=tgt_indices,
        )

    def _overlap_bias(self, hidden, target_vocab=120, shared=50):
        """Compute a VOCAB_OVERLAP bias for ``hidden``.

        Uses a random source LM head (source vocab = 100, dim = 32) and maps
        source tokens ``0..shared-1`` onto the same target token ids.

        Returns:
            ``(bias, tgt_idx)`` — the bias tensor and the mapped target ids.
        """
        from avp.rosetta.logit_guided import compute_cross_model_logit_bias
        from avp.types import ProjectionMethod

        lm_head_w = torch.randn(100, 32)  # source vocab=100
        src_idx = torch.arange(shared)
        tgt_idx = torch.arange(shared)

        avp_map = self._make_avp_map(
            ProjectionMethod.VOCAB_OVERLAP,
            src_indices=src_idx,
            tgt_indices=tgt_idx,
        )
        bias = compute_cross_model_logit_bias(
            hidden, lm_head_w, avp_map, target_vocab,
        )
        return bias, tgt_idx

    def test_vocab_overlap_bias_shape(self):
        """Bias tensor has correct shape matching target vocab."""
        target_vocab = 120
        bias, _ = self._overlap_bias(torch.randn(1, 32), target_vocab=target_vocab)

        assert bias.shape == (target_vocab,)

    def test_vocab_overlap_unmapped_tokens_zero(self):
        """Unmapped tokens should have zero bias."""
        target_vocab = 120
        bias, tgt_idx = self._overlap_bias(torch.randn(1, 32), target_vocab=target_vocab)

        # Only target tokens 0-49 are mapped; tokens 50-119 must stay exactly
        # zero even though the mapped entries are shifted to zero mean.
        unmapped_mask = torch.ones(target_vocab, dtype=torch.bool)
        unmapped_mask[tgt_idx] = False
        assert (bias[unmapped_mask] == 0.0).all()

    def test_vocab_overlap_bias_is_zero_mean(self):
        """Mapped tokens should have zero-mean bias."""
        bias, tgt_idx = self._overlap_bias(torch.randn(1, 32))

        mapped_bias = bias[tgt_idx]
        assert abs(mapped_bias.mean().item()) < 1e-5

    def test_vocab_mediated_bias(self):
        """VOCAB_MEDIATED uses direct 1:1 mapping."""
        from avp.rosetta.logit_guided import compute_cross_model_logit_bias
        from avp.types import ProjectionMethod

        hidden = torch.randn(1, 32)
        lm_head_w = torch.randn(100, 32)
        target_vocab = 100  # same size as the source vocab

        avp_map = self._make_avp_map(ProjectionMethod.VOCAB_MEDIATED)

        bias = compute_cross_model_logit_bias(
            hidden, lm_head_w, avp_map, target_vocab,
        )

        assert bias.shape == (target_vocab,)
        # At least some entries must carry a non-zero bias.
        assert (bias != 0.0).sum() > 0

    def test_ridge_returns_zero_bias(self):
        """RIDGE method has no token-level mapping — returns zero bias."""
        from avp.rosetta.logit_guided import compute_cross_model_logit_bias
        from avp.types import ProjectionMethod

        hidden = torch.randn(1, 32)
        lm_head_w = torch.randn(100, 32)
        target_vocab = 120

        avp_map = self._make_avp_map(ProjectionMethod.RIDGE)

        bias = compute_cross_model_logit_bias(
            hidden, lm_head_w, avp_map, target_vocab,
        )

        assert (bias == 0.0).all()

    def test_1d_hidden_state(self):
        """Should handle 1D hidden state (no batch dim)."""
        target_vocab = 120
        bias, _ = self._overlap_bias(torch.randn(32), target_vocab=target_vocab)

        assert bias.shape == (target_vocab,)


# ---------------------------------------------------------------------------
# Unit tests for CrossModelLogitBias processor
# ---------------------------------------------------------------------------

@requires_torch
class TestCrossModelLogitBias:
    """Test the LogitsProcessor behavior."""

    def test_applies_bias(self):
        """Bias should modify scores."""
        from avp.rosetta.logit_guided import CrossModelLogitBias

        bias = torch.tensor([0.0, 1.0, -1.0, 0.5])
        processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=1.0)

        input_ids = torch.tensor([[1, 2, 3]])
        scores = torch.tensor([[0.0, 0.0, 0.0, 0.0]])

        result = processor(input_ids, scores)

        # With uniform scores and confidence_threshold=1.0 (always apply),
        # result should be scores + alpha * bias
        expected = scores + bias
        assert torch.allclose(result, expected, atol=1e-5)

    def test_confidence_gating_suppresses_bias(self):
        """When target is confident, bias should be suppressed."""
        from avp.rosetta.logit_guided import CrossModelLogitBias

        bias = torch.tensor([0.0, 10.0, -10.0, 5.0])
        processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=0.5)

        input_ids = torch.tensor([[1]])
        # Very confident scores — token 0 dominates
        scores = torch.tensor([[100.0, -100.0, -100.0, -100.0]])

        result = processor(input_ids, scores)

        # Max prob ≈ 1.0 >> threshold 0.5, so bias should be suppressed
        assert torch.allclose(result, scores, atol=1e-5)

    def test_confidence_gating_allows_bias_when_uncertain(self):
        """When target is uncertain, bias should be applied."""
        from avp.rosetta.logit_guided import CrossModelLogitBias

        bias = torch.tensor([0.0, 1.0, -1.0, 0.5])
        processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=0.9)

        input_ids = torch.tensor([[1]])
        # Uniform scores — max_prob = 0.25 < 0.9
        scores = torch.zeros(1, 4)

        result = processor(input_ids, scores)

        # Should apply bias since target is uncertain
        expected = scores + bias
        assert torch.allclose(result, expected, atol=1e-5)

    def test_alpha_scaling(self):
        """Alpha should scale the bias."""
        from avp.rosetta.logit_guided import CrossModelLogitBias

        bias = torch.tensor([1.0, 2.0, 3.0])
        processor = CrossModelLogitBias(bias, alpha=0.5, confidence_threshold=1.0)

        input_ids = torch.tensor([[1]])
        scores = torch.zeros(1, 3)

        result = processor(input_ids, scores)
        expected = torch.tensor([[0.5, 1.0, 1.5]])
        assert torch.allclose(result, expected, atol=1e-5)

    def test_batch_confidence_gating(self):
        """Per-batch-element gating should work."""
        from avp.rosetta.logit_guided import CrossModelLogitBias

        bias = torch.tensor([0.0, 1.0, -1.0])
        processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=0.5)

        input_ids = torch.tensor([[1], [2]])
        # Batch elem 0: confident (token 0 dominates)
        # Batch elem 1: uncertain (uniform)
        scores = torch.tensor([
            [100.0, -100.0, -100.0],
            [0.0, 0.0, 0.0],
        ])

        result = processor(input_ids, scores)

        # Elem 0: bias suppressed
        assert torch.allclose(result[0], scores[0], atol=1e-5)
        # Elem 1: bias applied
        expected_1 = scores[1] + bias
        assert torch.allclose(result[1], expected_1, atol=1e-5)
"""Tests for mid-layer injection cross-model transfer."""

import pytest
import torch

from avp.rosetta.mid_layer import (
    DEFAULT_DEPTH_RATIO,
    compute_extraction_layer,
    compute_injection_layer,
    extract_mid_layer_hidden,
    mid_layer_injection_hook,
    _get_decoder_layers,
)


def _fake_outputs(hidden_states):
    """Wrap ``hidden_states`` in a bare object exposing a ``hidden_states`` attribute."""
    class MockOutputs:
        pass

    outputs = MockOutputs()
    outputs.hidden_states = hidden_states
    return outputs


def _single_layer_model(layer):
    """Build a minimal object exposing ``model.layers`` with ``layer`` as its only entry."""
    class SimpleModel:
        def __init__(self):
            self.model = type("inner", (), {"layers": torch.nn.ModuleList([layer])})()

    return SimpleModel()


def _identity_linear():
    """Return a bias-free 16x16 linear layer whose weight is the identity matrix."""
    layer = torch.nn.Linear(16, 16, bias=False)
    with torch.no_grad():
        layer.weight.copy_(torch.eye(16))
    return layer


class TestLayerComputation:
    """Checks for the extraction/injection layer index helpers."""

    def test_default_depth_ratio(self):
        """The default depth ratio constant is 0.75."""
        assert DEFAULT_DEPTH_RATIO == 0.75

    def test_extraction_layer_28_layers(self):
        """28 layers at the default ratio -> layer 21, i.e. int(28 * 0.75)."""
        assert compute_extraction_layer(28) == 21

    def test_extraction_layer_32_layers(self):
        """32 layers at the default ratio -> layer 24."""
        assert compute_extraction_layer(32) == 24

    def test_injection_layer_proportional(self):
        """A 32-layer target at the default ratio -> inject at layer 24."""
        assert compute_injection_layer(32) == 24

    def test_custom_depth_ratio(self):
        """An explicit ratio of 0.5 on 28 layers -> layer 14."""
        assert compute_extraction_layer(28, depth_ratio=0.5) == 14

    def test_depth_ratio_zero(self):
        """A ratio of 0.0 selects layer 0."""
        assert compute_extraction_layer(28, depth_ratio=0.0) == 0

    def test_depth_ratio_one(self):
        """A ratio of 1.0 is expected to clamp to the last layer (27 of 28)."""
        assert compute_extraction_layer(28, depth_ratio=1.0) == 27

    def test_small_model(self):
        """Four layers at ratio 0.75 -> layer 3."""
        assert compute_extraction_layer(4, depth_ratio=0.75) == 3


class TestExtractMidLayerHidden:
    """Checks for mid-layer hidden state extraction."""

    def test_extracts_correct_layer(self):
        """The result equals the last-token slice of hidden_states[layer + 1]."""
        # hidden_states holds num_layers + 1 tensors: index 0 is the embedding
        # output and index i + 1 is the output of layer i.
        n_layers, n_batch, n_tokens, width = 10, 1, 5, 32
        states = tuple(
            [torch.randn(n_batch, n_tokens, width) * (idx + 1) for idx in range(n_layers + 1)]
        )
        wanted_layer = 7

        extracted = extract_mid_layer_hidden(_fake_outputs(states), wanted_layer)

        assert extracted.shape == (n_batch, width)
        assert torch.allclose(extracted, states[wanted_layer + 1][:, -1, :])

    def test_clamps_to_valid_range(self):
        """Requesting a layer beyond the available range still returns a [1, 16] tensor."""
        # Six tensors: five layers plus the embedding output.
        states = tuple([torch.randn(1, 5, 16) for _ in range(6)])

        # Layer 10 is requested although only five layers exist.
        extracted = extract_mid_layer_hidden(_fake_outputs(states), 10)
        assert extracted.shape == (1, 16)


class TestMidLayerInjectionHook:
    """Checks for the forward-hook context manager."""

    def test_hook_modifies_output(self):
        """A forward pass inside the context runs, and the hook is gone afterwards."""
        layer = _identity_linear()
        model = _single_layer_model(layer)
        projected = torch.ones(1, 16) * 42.0  # distinctive value

        with mid_layer_injection_hook(model, 0, projected):
            # One forward pass through the hooked layer.
            layer(torch.randn(1, 5, 16))

        # Leaving the context must remove the hook.
        assert len(layer._forward_hooks) == 0

    def test_hook_fires_only_once(self):
        """Two forward passes inside the context run, and the hook is gone afterwards."""
        layer = _identity_linear()
        model = _single_layer_model(layer)
        projected = torch.ones(1, 16) * 42.0

        with mid_layer_injection_hook(model, 0, projected):
            # First pass, then a second pass with a different sequence length.
            layer(torch.randn(1, 5, 16))
            layer(torch.randn(1, 3, 16))

        assert len(layer._forward_hooks) == 0

    def test_hook_cleanup_on_exception(self):
        """The hook is removed even when the context body raises."""
        layer = torch.nn.Linear(16, 16, bias=False)
        model = _single_layer_model(layer)
        projected = torch.ones(1, 16)

        with pytest.raises(RuntimeError):
            with mid_layer_injection_hook(model, 0, projected):
                raise RuntimeError("test error")

        assert len(layer._forward_hooks) == 0


class TestGetDecoderLayers:
    """Checks for decoder-layer discovery."""

    @staticmethod
    def _four_linears():
        """Return a ModuleList of four 16x16 linear layers."""
        return torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])

    def test_llama_style_model(self):
        """Layers reachable as model.model.layers are returned as-is."""
        layers = self._four_linears()
        inner = type("inner", (), {"layers": layers})()
        model = type("model", (), {"model": inner})()
        assert _get_decoder_layers(model) is layers

    def test_gpt2_style_model(self):
        """Layers reachable as model.transformer.h are returned as-is."""
        h = self._four_linears()
        transformer = type("transformer", (), {"h": h})()
        model = type("model", (), {"model": None, "transformer": transformer})()
        assert _get_decoder_layers(model) is h

    def test_unknown_model_raises(self):
        """A model exposing neither layout raises AttributeError."""
        model = type("model", (), {"model": None, "transformer": None})()
        with pytest.raises(AttributeError, match="Cannot find decoder layers"):
            _get_decoder_layers(model)


"""Tests for smart routing / task classification in quality gate."""

import pytest

from avp.rosetta.quality import (
    TaskClassification,
    TransferQualityConfig,
    TransferQualityResult,
    assess_transfer,
    classify_task,
)


class TestClassifyTask:
    """Checks for classify_task() prompt classification."""

    def test_math_problem_classified_as_structured(self):
        """An arithmetic prompt is structured with a positive score."""
        verdict = classify_task("Solve: 24 * 17 + 3 = ?", prompt_tokens=50)
        assert verdict.task_type == "structured"
        assert verdict.score > 0

    def test_code_problem_classified_as_structured(self):
        """A code-completion prompt is structured with a positive score."""
        text = "def fibonacci(n):\n # implement this function\n return result"
        verdict = classify_task(text, prompt_tokens=30)
        assert verdict.task_type == "structured"
        assert verdict.score > 0

    def test_comprehension_classified_as_comprehension(self):
        """A multi-paragraph passage with a 'who' question is comprehension."""
        text = (
            "The industrial revolution began in Britain in the late 18th century. "
            "It was a period of great change in manufacturing, mining, and transport.\n\n"
            "The textile industry was the first to adopt new methods of production.\n\n"
            "Steam power was central to the changes.\n\n"
            "Who started the industrial revolution?"
        )
        verdict = classify_task(text, prompt_tokens=800)
        assert verdict.task_type == "comprehension"
        assert verdict.score < 0

    def test_high_digit_density_is_structured(self):
        """A digit-heavy prompt gets a positive digit_density feature."""
        verdict = classify_task("Calculate 123 + 456 * 789 - 321 / 654 = ?", prompt_tokens=20)
        assert verdict.features["digit_density"] > 0
        assert verdict.task_type == "structured"

    def test_no_digits_no_digit_bonus(self):
        """A prompt without digits has a digit_density feature of 0."""
        verdict = classify_task("Tell me about the history of France", prompt_tokens=10)
        assert verdict.features["digit_density"] == 0

    def test_long_prompt_penalized(self):
        """600 prompt tokens yield a token_count feature of -2."""
        verdict = classify_task("Some text", prompt_tokens=600)
        assert verdict.features["token_count"] == -2

    def test_short_prompt_rewarded(self):
        """50 prompt tokens yield a token_count feature of 1."""
        verdict = classify_task("Solve: 2 + 2", prompt_tokens=50)
        assert verdict.features["token_count"] == 1

    def test_math_markers_detected(self):
        """LaTeX-style math commands give a positive markers feature."""
        verdict = classify_task(r"Find \frac{x}{y} where \sqrt{x} = 5", prompt_tokens=20)
        assert verdict.features["markers"] > 0

    def test_code_markers_detected(self):
        """A fenced code block gives a positive markers feature."""
        verdict = classify_task("```python\ndef foo():\n return 42\n```", prompt_tokens=20)
        assert verdict.features["markers"] > 0

    def test_comprehension_questions_detected(self):
        """A passage-based 'who' question gives a negative question_type feature."""
        text = "Based on the passage above, who was the first president?"
        verdict = classify_task(text, prompt_tokens=100)
        assert verdict.features["question_type"] < 0

    def test_structured_starters_detected(self):
        """A prompt starting with 'Solve' gives a positive question_type feature."""
        text = "Solve the following equation: 3x + 5 = 20"
        verdict = classify_task(text, prompt_tokens=50)
        assert verdict.features["question_type"] > 0

    def test_multi_paragraph_penalized(self):
        """Several blank-line separated paragraphs give a negative paragraphs feature."""
        text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.\n\nQuestion?"
        verdict = classify_task(text, prompt_tokens=100)
        assert verdict.features["paragraphs"] < 0

    def test_returns_task_classification_dataclass(self):
        """The result is a TaskClassification with the expected field types."""
        verdict = classify_task("test prompt", prompt_tokens=50)
        assert isinstance(verdict, TaskClassification)
        assert verdict.task_type in ("structured", "comprehension")
        assert isinstance(verdict.score, int)
        assert isinstance(verdict.features, dict)

    def test_gsm8k_style_prompt(self):
        """A GSM8K-style word problem is classified as structured."""
        text = (
            "Janet's ducks lay 16 eggs per day. She eats three for breakfast "
            "every morning and bakes muffins for her friends every day with four. "
            "She sells the remainder at the farmers' market daily for $2 per "
            "fresh duck egg. How much in dollars does she make every day at the "
            "farmers' market?"
        )
        verdict = classify_task(text, prompt_tokens=80)
        assert verdict.task_type == "structured"

    def test_hotpotqa_style_prompt(self):
        """A HotpotQA-style multi-paragraph question is classified as comprehension."""
        text = (
            "Walter Elias Disney was an American entrepreneur, animator, voice "
            "actor and film producer. A pioneer of the American animation "
            "industry, he introduced several developments in the production of "
            "cartoons.\n\n"
            "As a film producer, Disney holds the record for most Academy Awards "
            "earned by an individual, having won 22 Oscars from 59 nominations.\n\n"
            "He was presented with two Golden Globe Special Achievement Awards "
            "and an Emmy Award, among other honors.\n\n"
            "Several of his films are included in the National Film Registry by "
            "the Library of Congress.\n\n"
            "Based on the passage, how many Academy Awards did Disney win?"
        )
        verdict = classify_task(text, prompt_tokens=1200)
        assert verdict.task_type == "comprehension"

    def test_empty_prompt(self):
        """An empty prompt still yields a TaskClassification."""
        verdict = classify_task("", prompt_tokens=0)
        assert isinstance(verdict, TaskClassification)

    def test_zero_tokens_no_token_bonus(self):
        """Zero prompt tokens yield a token_count feature of 0."""
        verdict = classify_task("test", prompt_tokens=0)
        assert verdict.features["token_count"] == 0


class TestAssessTransferWithPromptText:
    """Checks for assess_transfer() when prompt text is supplied."""

    def test_backward_compatible_no_prompt_text(self):
        """Without prompt_text, 200 tokens recommends latent and has no classification."""
        outcome = assess_transfer(prompt_tokens=200)
        assert outcome.recommend_latent is True
        assert outcome.task_classification is None

    def test_backward_compatible_over_limit(self):
        """Without prompt_text, 400 tokens does not recommend latent."""
        outcome = assess_transfer(prompt_tokens=400)
        assert outcome.recommend_latent is False
        assert outcome.task_classification is None

    def test_structured_prompt_recommended_latent(self):
        """A structured prompt under the limit recommends latent and is classified."""
        outcome = assess_transfer(
            prompt_tokens=100,
            prompt_text="Solve: 24 * 17 + 3 = ?",
        )
        assert outcome.recommend_latent is True
        assert outcome.task_classification is not None
        assert outcome.task_classification.task_type == "structured"

    def test_comprehension_prompt_blocks_latent(self):
        """A comprehension prompt does not recommend latent even at 200 tokens."""
        text = (
            "The Mona Lisa is a half-length portrait painting by Italian "
            "artist Leonardo da Vinci.\n\n"
            "It has been described as the best known painting in the world.\n\n"
            "The painting is thought to be a portrait of Lisa Gherardini.\n\n"
            "According to the passage, who painted the Mona Lisa?"
        )
        outcome = assess_transfer(prompt_tokens=200, prompt_text=text)
        assert outcome.recommend_latent is False
        assert outcome.task_classification is not None
        assert outcome.task_classification.task_type == "comprehension"

    def test_structured_extends_token_limit(self):
        """A strongly structured prompt at 350 tokens still recommends latent."""
        text = (
            "```python\n"
            "def calculate(x, y):\n"
            " result = x * y + 100 / 50 - 25\n"
            " return result\n"
            "```\n"
            "Solve: compute calculate(10, 20) step by step. "
            "Find the value of 10 * 20 + 100 / 50 - 25"
        )
        # 350 tokens is above the limit of 300 used by the boundary tests.
        outcome = assess_transfer(prompt_tokens=350, prompt_text=text)
        assert outcome.recommend_latent is True

    def test_structured_still_blocked_beyond_extended_limit(self):
        """A structured prompt at 500 tokens does not recommend latent."""
        outcome = assess_transfer(prompt_tokens=500, prompt_text="Solve: 2 + 2 = ?")
        assert outcome.recommend_latent is False

    def test_task_classification_disabled(self):
        """With use_task_classification=False, no classification is attached."""
        cfg = TransferQualityConfig(use_task_classification=False)
        outcome = assess_transfer(
            prompt_tokens=200,
            prompt_text="Who painted the Mona Lisa?",
            config=cfg,
        )
        assert outcome.task_classification is None
        assert outcome.recommend_latent is True  # 200 tokens is under the limit

    def test_result_includes_classification(self):
        """The result is a TransferQualityResult carrying a TaskClassification."""
        outcome = assess_transfer(
            prompt_tokens=100,
            prompt_text="Solve: 2 + 2",
        )
        assert isinstance(outcome, TransferQualityResult)
        assert outcome.task_classification is not None
        assert isinstance(outcome.task_classification, TaskClassification)

    def test_effective_rank_still_works_with_prompt_text(self):
        """The effective-rank check can be combined with prompt text."""
        import torch

        cfg = TransferQualityConfig(check_effective_rank=True)
        # A rectangular 10x64 identity matrix serves as the hidden-state sample.
        sample_hidden = torch.eye(10, 64)
        outcome = assess_transfer(
            prompt_tokens=100,
            prompt_text="Solve: 2 + 2",
            hidden_states=sample_hidden,
            config=cfg,
        )
        assert isinstance(outcome, TransferQualityResult)


class TestAssessTransferBackwardCompat:
    """Checks that token-count-only behavior is preserved."""

    def test_short_prompt_recommends_latent(self):
        """200 tokens recommends latent."""
        assert assess_transfer(prompt_tokens=200).recommend_latent is True

    def test_long_prompt_recommends_text(self):
        """1500 tokens does not recommend latent."""
        assert assess_transfer(prompt_tokens=1500).recommend_latent is False

    def test_boundary_at_300(self):
        """300 tokens recommends latent; 301 does not."""
        at_limit = assess_transfer(prompt_tokens=300)
        assert at_limit.recommend_latent is True
        over_limit = assess_transfer(prompt_tokens=301)
        assert over_limit.recommend_latent is False

    def test_custom_config_threshold(self):
        """With max_prompt_tokens=100, the boundary moves to 100/101."""
        cfg = TransferQualityConfig(max_prompt_tokens=100)
        at_limit = assess_transfer(prompt_tokens=100, config=cfg)
        assert at_limit.recommend_latent is True
        over_limit = assess_transfer(prompt_tokens=101, config=cfg)
        assert over_limit.recommend_latent is False

    def test_zero_tokens(self):
        """Zero tokens recommends latent."""
        assert assess_transfer(prompt_tokens=0).recommend_latent is True
+""" + +import pytest +import torch +import torch.nn as nn + +from avp.rosetta.train import LayerProjector, TrainConfig +from avp.rosetta.trained_hooks import trained_multi_layer_hook, _get_decoder_layers +from avp.rosetta.calibrate import AVPMap +from avp.types import ProjectionMethod + + +# --- TrainConfig tests --- + + +class TestTrainConfig: + def test_defaults(self): + config = TrainConfig() + assert config.num_samples == 5000 + assert config.batch_size == 4 + assert config.learning_rate == 1e-4 + assert config.num_epochs == 2 + assert config.gate_reg_weight == 0.01 + assert config.gate_init == -3.0 + assert config.max_seq_len == 256 + assert config.warmup_steps == 100 + assert config.seed == 42 + assert config.mse_aux_weight == 0.1 + assert config.use_ntp_loss is False + + def test_custom_values(self): + config = TrainConfig(num_samples=100, batch_size=8, learning_rate=1e-3) + assert config.num_samples == 100 + assert config.batch_size == 8 + assert config.learning_rate == 1e-3 + + +# --- LayerProjector tests --- + + +class TestLayerProjector: + def test_init_shapes(self): + proj = LayerProjector(source_dim=128, target_dim=64, num_layers=4) + assert len(proj.layer_projections) == 4 + assert len(proj.layer_gates) == 4 + # Each projection: [D_src] -> [D_tgt] + assert proj.layer_projections[0].in_features == 128 + assert proj.layer_projections[0].out_features == 64 + + def test_gate_init_near_zero(self): + """Gates should be initialized near zero (sigmoid(-3) ~ 0.05).""" + proj = LayerProjector(source_dim=32, target_dim=32, num_layers=8) + for gate_logit in proj.layer_gates: + gate_val = torch.sigmoid(gate_logit).item() + assert gate_val < 0.1, f"Gate initialized too high: {gate_val}" + + def test_forward_output_shapes(self): + proj = LayerProjector(source_dim=64, target_dim=32, num_layers=3) + src = torch.randn(2, 64) # batch=2 + results = proj.forward(src) + assert len(results) == 3 + for projected, gate in results: + assert projected.shape == (2, 32) + assert 
0 <= gate <= 1 + + def test_get_active_layers_initial(self): + """Initially, all gates near zero → no active layers at default threshold.""" + proj = LayerProjector(source_dim=32, target_dim=32, num_layers=8) + active = proj.get_active_layers(threshold=0.1) + assert len(active) == 0 + + def test_get_active_layers_after_setting_gate(self): + proj = LayerProjector(source_dim=32, target_dim=32, num_layers=4) + # Force gate 1 and 3 open + with torch.no_grad(): + proj.layer_gates[1].fill_(5.0) # sigmoid(5) ~ 0.993 + proj.layer_gates[3].fill_(3.0) # sigmoid(3) ~ 0.953 + active = proj.get_active_layers(threshold=0.5) + assert active == [1, 3] + + def test_export_weights(self): + proj = LayerProjector(source_dim=64, target_dim=32, num_layers=3) + weights, biases, gates = proj.export_weights() + assert len(weights) == 3 + assert len(biases) == 3 + assert len(gates) == 3 + for w in weights: + assert w.shape == (32, 64) # nn.Linear stores [out, in] + for b in biases: + assert b.shape == (32,) + for g in gates: + assert isinstance(g, float) + assert 0 <= g <= 1 + + def test_parameters_count(self): + proj = LayerProjector(source_dim=64, target_dim=32, num_layers=3) + params = list(proj.parameters()) + # 3 layers × (weight + bias) + 3 gates = 9 parameters + assert len(params) == 9 + + def test_to_device(self): + proj = LayerProjector(source_dim=32, target_dim=16, num_layers=2) + proj.to("cpu") # should work without error + # Verify parameters are on cpu + for p in proj.parameters(): + assert str(p.device) == "cpu" + + +# --- Trained hooks tests --- + + +class FakeDecoderLayer(nn.Module): + """Minimal decoder layer for hook testing.""" + def __init__(self, dim): + super().__init__() + self.linear = nn.Linear(dim, dim) + + def forward(self, x): + return (self.linear(x),) + + +class FakeModel(nn.Module): + """Minimal model with decoder layers for hook testing.""" + def __init__(self, dim, num_layers): + super().__init__() + self.model = nn.Module() + self.model.layers = 
nn.ModuleList([ + FakeDecoderLayer(dim) for _ in range(num_layers) + ]) + + +class TestTrainedMultiLayerHook: + def test_hook_adds_projection(self): + dim = 16 + model = FakeModel(dim=dim, num_layers=4) + model.eval() + + projected = torch.ones(1, dim) * 10.0 + gate = 0.5 + + layer_projections = [None, (projected, gate), None, None] + + x = torch.randn(1, 5, dim) + + with trained_multi_layer_hook(model, layer_projections): + # Forward through layer 1 (the hooked layer) + layers = _get_decoder_layers(model) + output = layers[1](x) + + # The hook should have modified the last token + modified_hidden = output[0] + # Verify last token was modified (added gate * projection) + # Original output + 0.5 * 10.0 = original + 5.0 + with torch.no_grad(): + original_output = model.model.layers[1](x) + original_last = original_output[0][:, -1, :] + modified_last = modified_hidden[:, -1, :] + diff = (modified_last - original_last).abs().mean().item() + assert diff > 1.0, f"Hook didn't modify output enough: diff={diff}" + + def test_hook_fires_once(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=2) + model.eval() + + projected = torch.ones(1, dim) * 100.0 + layer_projections = [(projected, 1.0), None] + + x = torch.randn(1, 3, dim) + + with trained_multi_layer_hook(model, layer_projections): + layers = _get_decoder_layers(model) + # First forward — hook fires + out1 = layers[0](x) + # Second forward — hook should NOT fire + out2 = layers[0](x) + + # out1 should differ from out2 (hook only fires once) + # Actually both outputs are from the same input x, but the first + # has the injection and the second doesn't + last1 = out1[0][:, -1, :].detach() + last2 = out2[0][:, -1, :].detach() + diff = (last1 - last2).abs().mean().item() + assert diff > 1.0, f"Hook fired more than once or didn't fire: diff={diff}" + + def test_hook_cleanup(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=2) + + projected = torch.ones(1, dim) + layer_projections = [(projected, 0.5), 
(projected, 0.3)] + + with trained_multi_layer_hook(model, layer_projections): + pass + + # Verify hooks are removed + layers = _get_decoder_layers(model) + for layer in layers: + assert len(layer._forward_hooks) == 0 + + def test_skip_none_layers(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=4) + + # Only layer 2 active + layer_projections = [None, None, (torch.ones(1, dim), 0.5), None] + + with trained_multi_layer_hook(model, layer_projections): + layers = _get_decoder_layers(model) + # Layers 0, 1, 3 should have no hooks + assert len(layers[0]._forward_hooks) == 0 + assert len(layers[1]._forward_hooks) == 0 + assert len(layers[2]._forward_hooks) == 1 + assert len(layers[3]._forward_hooks) == 0 + + def test_exception_cleanup(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=2) + + projected = torch.ones(1, dim) + layer_projections = [(projected, 0.5), (projected, 0.3)] + + with pytest.raises(RuntimeError): + with trained_multi_layer_hook(model, layer_projections): + raise RuntimeError("test error") + + # Hooks should still be cleaned up + layers = _get_decoder_layers(model) + for layer in layers: + assert len(layer._forward_hooks) == 0 + + +# --- Registry serialization tests --- + + +class TestRegistrySerialization: + def test_save_load_trained_map(self, tmp_path): + """Round-trip save/load of AVPMap with trained projection fields.""" + from avp.rosetta.registry import save_map, load_map + + num_layers = 3 + src_dim = 64 + tgt_dim = 32 + + avp_map = AVPMap( + source_model_id="test/source", + source_hash="abc123" * 8, + source_dim=src_dim, + target_model_id="test/target", + target_hash="def456" * 8, + target_dim=tgt_dim, + w_map=torch.zeros(1), + bias=None, + target_norm=torch.tensor(1.0), + method=ProjectionMethod.TRAINED, + anchor_count=5000, + validation_score=0.85, + layer_weights=[torch.randn(tgt_dim, src_dim) for _ in range(num_layers)], + layer_biases=[torch.randn(tgt_dim) for _ in range(num_layers)], + layer_gates=[0.01, 0.95, 0.03], + ) 
+ + path = save_map(avp_map, map_dir=tmp_path) + assert path.exists() + + loaded = load_map( + avp_map.source_hash, avp_map.target_hash, + device="cpu", map_dir=tmp_path, + ) + assert loaded is not None + assert loaded.method == ProjectionMethod.TRAINED + assert len(loaded.layer_weights) == num_layers + assert len(loaded.layer_biases) == num_layers + assert loaded.layer_gates == [0.01, 0.95, 0.03] + + # Verify tensor shapes preserved + for w in loaded.layer_weights: + assert w.shape == (tgt_dim, src_dim) + for b in loaded.layer_biases: + assert b.shape == (tgt_dim,) + + def test_save_load_non_trained_map_no_layer_fields(self, tmp_path): + """Non-trained maps should have None for layer fields.""" + from avp.rosetta.registry import save_map, load_map + + avp_map = AVPMap( + source_model_id="test/source", + source_hash="aaa111" * 8, + source_dim=64, + target_model_id="test/target", + target_hash="bbb222" * 8, + target_dim=32, + w_map=torch.randn(64, 32), + bias=None, + target_norm=torch.tensor(1.0), + method=ProjectionMethod.RIDGE, + anchor_count=50, + validation_score=0.5, + ) + + save_map(avp_map, map_dir=tmp_path) + loaded = load_map( + avp_map.source_hash, avp_map.target_hash, + device="cpu", map_dir=tmp_path, + ) + assert loaded is not None + assert loaded.layer_weights is None + assert loaded.layer_biases is None + assert loaded.layer_gates is None + + +# --- AVPMap TRAINED method test --- + + +class TestAVPMapTrained: + def test_trained_method_enum(self): + assert ProjectionMethod.TRAINED.value == "trained" + + def test_avpmap_with_trained_fields(self): + avp_map = AVPMap( + source_model_id="src", + source_hash="h1", + source_dim=128, + target_model_id="tgt", + target_hash="h2", + target_dim=64, + w_map=torch.zeros(1), + bias=None, + target_norm=torch.tensor(1.0), + method="trained", # string should convert + anchor_count=1000, + validation_score=0.9, + layer_weights=[torch.randn(64, 128)], + layer_biases=[torch.randn(64)], + layer_gates=[0.5], + ) + assert 
avp_map.method == ProjectionMethod.TRAINED