From 75d5935d18795d14c365a68c7616752116a11b2c Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 00:47:53 +0000 Subject: [PATCH 01/14] Add logit-guided decoding for cross-model communication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New approach: instead of compressing source model info into a single virtual token (rosetta), distribute the signal as additive logit bias during target generation. Source model's vocabulary distribution is mapped through vocab overlap to target vocabulary. Implementation: - New: rosetta/logit_guided.py — CrossModelLogitBias processor + bias computation - Modified: huggingface.py — cross_model_method="logit_guided" option in generate() - Modified: easy.py — pass through cross_model_method and logit_bias_alpha - New: pipeline_logit_guided.py — benchmark pipeline for GSM8K 2-agent - Modified: run_gsm8k_2agent.py — --mode logit_guided support - Modified: shared/generation.py — logits_processor kwarg support - New: test_logit_guided.py — 11 unit tests (bias shape, zero-mean, gating, scaling) Key features: - Confidence gating: skip bias when target is already confident (>0.8 max prob) - Zero-mean bias: doesn't shift distribution center, only nudges relative prefs - Alpha scaling: default 0.5 (conservative for cross-vocab mapping) - Falls back gracefully for RIDGE/PROCRUSTES (no token-level mapping) Co-Authored-By: Claude Opus 4.6 --- .../gsm8k_2agent/pipeline_logit_guided.py | 251 ++++++++++++++++ benchmarks/gsm8k_2agent/run_gsm8k_2agent.py | 64 ++++- benchmarks/shared/generation.py | 6 +- src/avp/connectors/huggingface.py | 105 +++++++ src/avp/easy.py | 4 + src/avp/rosetta/logit_guided.py | 134 +++++++++ tests/test_logit_guided.py | 269 ++++++++++++++++++ 7 files changed, 827 insertions(+), 6 deletions(-) create mode 100644 benchmarks/gsm8k_2agent/pipeline_logit_guided.py create mode 100644 src/avp/rosetta/logit_guided.py create mode 100644 tests/test_logit_guided.py diff --git a/benchmarks/gsm8k_2agent/pipeline_logit_guided.py b/benchmarks/gsm8k_2agent/pipeline_logit_guided.py new file mode 100644 index 0000000..9e7e256 --- /dev/null +++ b/benchmarks/gsm8k_2agent/pipeline_logit_guided.py @@ -0,0 +1,251 @@ +"""Logit-guided pipeline: 2-agent chain with cross-model logit bias. + +Researcher runs on model A (latent steps → extract hidden state → compute logits). +Solver runs on model B (generate with logit bias from mapped source distribution). + +Unlike rosetta (single virtual token in KV-cache), logit-guided distributes +the source signal across the target's entire autoregressive generation. +""" + +import time +from typing import Any, Dict, List + +import torch +from transformers import LogitsProcessorList + +from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt +from benchmarks.shared.metrics import gpu_memory_tracker +from .agents import AGENTS, build_latent_prompt +from .evaluate import extract_gold, extract_gsm8k_answer, check_correct + + +def run_logit_guided_pipeline( + conn_a: Any, + model_a: Any, + tokenizer_a: Any, + identity_a: Any, + model_b: Any, + tokenizer_b: Any, + device: str, + avp_map: Any, + question: str, + gold_solution: str, + latent_steps: int = 10, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.95, + verbose: bool = False, + logit_bias_alpha: float = 0.5, + logit_bias_confidence_threshold: float = 0.8, +) -> Dict: + """Run the 2-agent cross-model pipeline with logit-guided decoding. + + Researcher (model A): latent steps → extract hidden state → compute logits + Solver (model B): generate with mapped logit bias (no KV-cache priming) + """ + from avp.rosetta.logit_guided import ( + CrossModelLogitBias, + compute_cross_model_logit_bias, + ) + + with gpu_memory_tracker(device) as mem: + t0 = time.perf_counter() + agent_traces: List[Dict] = [] + total_prompt_tokens = 0 + total_latent_steps = 0 + total_output_tokens = 0 + + researcher = AGENTS[0] + solver = AGENTS[1] + + # --- Agent 1: Researcher on model A (latent steps) --- + messages = build_latent_prompt(researcher.role, question) + prompt_text = render_prompt(tokenizer_a, messages) + input_ids, attention_mask = tokenize_prompt(tokenizer_a, prompt_text, device) + + agent_t0 = time.perf_counter() + prompt_tokens = int(input_ids.shape[-1]) + total_prompt_tokens += prompt_tokens + total_latent_steps += latent_steps + + # Run latent steps and collect hidden states + past_kv, hidden_states = conn_a.generate_latent_steps( + input_ids, latent_steps=latent_steps, attention_mask=attention_mask, + collect_hidden_states=True, + ) + + # Use last hidden state for logit computation + last_hidden = hidden_states[-1].unsqueeze(0) # [1, D_src] + + # Compute logit bias + bias_t0 = time.perf_counter() + + source_lm_head = model_a.get_output_embeddings() + if source_lm_head is None: + source_lm_head = getattr(model_a, "lm_head", None) + + target_vocab_size = model_b.config.vocab_size + bias = compute_cross_model_logit_bias( + source_hidden_state=last_hidden, + source_lm_head_weight=source_lm_head.weight, + avp_map=avp_map, + target_vocab_size=target_vocab_size, + ) + + bias_ms = (time.perf_counter() - bias_t0) * 1000 + nonzero_count = int((bias != 0).sum()) + bias_magnitude = float(bias[bias != 0].abs().mean()) if nonzero_count > 0 else 0.0 + + # Wire size: just the bias tensor (much smaller than text, larger than rosetta embed) + wire_bytes = bias.nelement() * bias.element_size() + + agent_time_ms = (time.perf_counter() - agent_t0) * 1000 + + agent_traces.append({ + "name": researcher.name, + "role": researcher.role, + "prompt_tokens": prompt_tokens, + "latent_steps": latent_steps, + "bias_compute_ms": bias_ms, + "wire_bytes": wire_bytes, + "agent_time_ms": agent_time_ms, + "output": "", + }) + + if verbose: + print(f" [{researcher.name}] latent steps={latent_steps}, " + f"bias compute={bias_ms:.1f}ms, " + f"nonzero={nonzero_count}/{target_vocab_size}, " + f"mean_abs_bias={bias_magnitude:.4f}") + + # Free model A KV-cache + del past_kv, hidden_states + if device == "cuda": + torch.cuda.empty_cache() + + # --- Agent 2: Solver on model B (generate with logit bias) --- + messages = build_latent_prompt(solver.role, question) + prompt_text = render_prompt(tokenizer_b, messages) + input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device) + + agent_t0 = time.perf_counter() + prompt_tokens = int(input_ids.shape[-1]) + total_prompt_tokens += prompt_tokens + + # Create logit bias processor + processor = CrossModelLogitBias( + bias=bias, + alpha=logit_bias_alpha, + confidence_threshold=logit_bias_confidence_threshold, + ) + + text, _ = generate_text( + model_b, tokenizer_b, input_ids, attention_mask, device, + past_key_values=None, # No KV-cache priming + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + logits_processor=LogitsProcessorList([processor]), + ) + + output_encoded = tokenizer_b(text, add_special_tokens=False) + output_tokens = len(output_encoded["input_ids"]) + total_output_tokens += output_tokens + agent_time_ms = (time.perf_counter() - agent_t0) * 1000 + + agent_traces.append({ + "name": solver.name, + "role": solver.role, + "prompt_tokens": prompt_tokens, + "output_tokens": output_tokens, + "agent_time_ms": agent_time_ms, + "output": text, + }) + + if verbose: + print(f" [{solver.name}] output ({len(text)} chars): {text[:200]}...") + + wall_time = time.perf_counter() - t0 + + total_tokens = total_prompt_tokens + total_latent_steps + total_output_tokens + tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0 + + gold = extract_gold(gold_solution) + prediction = extract_gsm8k_answer(agent_traces[-1]["output"]) + correct = check_correct(prediction, gold) + + return { + "question": question, + "gold": gold, + "prediction": prediction, + "raw_output": agent_traces[-1]["output"], + "correct": correct, + "wall_time": wall_time, + "total_prompt_tokens": total_prompt_tokens, + "total_latent_steps": total_latent_steps, + "total_output_tokens": total_output_tokens, + "total_tokens": total_tokens, + "tokens_per_sec": tokens_per_sec, + "peak_memory_mb": mem["peak_memory_mb"], + "bias_compute_ms": bias_ms, + "logit_bias_alpha": logit_bias_alpha, + "logit_bias_confidence_threshold": logit_bias_confidence_threshold, + "bias_nonzero_count": nonzero_count, + "bias_mean_magnitude": bias_magnitude, + "wire_bytes": wire_bytes, + "agents": agent_traces, + "mode": "logit_guided", + } + + +def run_logit_guided_benchmark( + conn_a: Any, + model_a: Any, + tokenizer_a: Any, + identity_a: Any, + model_b: Any, + tokenizer_b: Any, + device: str, + avp_map: Any, + dataset: List[Dict], + latent_steps: int = 10, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.95, + verbose: bool = False, + logit_bias_alpha: float = 0.5, + logit_bias_confidence_threshold: float = 0.8, +) -> List[Dict]: + """Run logit-guided pipeline on a list of GSM8K samples.""" + results = [] + for i, sample in enumerate(dataset): + if verbose: + print(f"\n[Logit-Guided] Sample {i + 1}/{len(dataset)}: " + f"{sample['question'][:80]}...") + + result = run_logit_guided_pipeline( + conn_a, model_a, tokenizer_a, identity_a, + model_b, tokenizer_b, device, avp_map, + question=sample["question"], + gold_solution=sample["answer"], + latent_steps=latent_steps, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + verbose=verbose, + logit_bias_alpha=logit_bias_alpha, + logit_bias_confidence_threshold=logit_bias_confidence_threshold, + ) + results.append(result) + + if verbose: + status = "CORRECT" if result["correct"] else "WRONG" + print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, " + f"time={result['wall_time']:.1f}s)") + else: + correct = sum(1 for r in results if r["correct"]) + print(f" [Logit-Guided a={logit_bias_alpha}] {i + 1}/{len(dataset)} " + f"({correct}/{i + 1} correct, {result['wall_time']:.1f}s)", + flush=True) + + return results diff --git a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py index f077d52..6857aa3 100644 --- a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py +++ b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py @@ -39,7 +39,8 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--mode", - choices=["latent", "text", "direct", "rosetta", "text_cross_model", "both", "all"], + choices=["latent", "text", "direct", "rosetta", "logit_guided", + "text_cross_model", "both", "all"], default="all", help="Pipeline(s) to run (default: all)", ) @@ -64,6 +65,10 @@ def parse_args() -> argparse.Namespace: help="Softmax temperature for cross-model projection (default: 1.0)") parser.add_argument("--num_transfer_states", type=int, default=1, help="Number of hidden states to transfer in rosetta mode (default: 1)") + parser.add_argument("--logit_bias_alpha", type=float, default=0.5, + help="Logit bias scaling factor for logit_guided mode (default: 0.5)") + parser.add_argument("--logit_bias_confidence_threshold", type=float, default=0.8, + help="Confidence threshold for logit bias gating (default: 0.8)") return parser.parse_args() @@ -104,6 +109,8 @@ def run_benchmark(config: dict) -> dict: output_dir = config.get("output_dir") projection_temperature = config.get("projection_temperature", 1.0) num_transfer_states = config.get("num_transfer_states", 1) + logit_bias_alpha = config.get("logit_bias_alpha", 0.5) + logit_bias_confidence_threshold = config.get("logit_bias_confidence_threshold", 0.8) model_b_name = config.get("model_b", "Qwen/Qwen2.5-0.5B-Instruct") @@ -111,12 +118,13 @@ def run_benchmark(config: dict) -> dict: run_latent = mode in ("latent", "both", "all") run_text = mode in ("text", "both", "all") run_rosetta = mode in ("rosetta", "all") + run_logit_guided = mode in ("logit_guided", "all") run_text_cross_model = mode in ("text_cross_model", "all") print(f"Device: {device}") print(f"Mode: {mode}") print(f"Model A: {model_name}") - if run_rosetta or run_text_cross_model: + if run_rosetta or run_logit_guided or run_text_cross_model: print(f"Model B: {model_b_name}") print(f"Samples: {max_samples}") print(f"Latent steps: {latent_steps}") @@ -124,7 +132,8 @@ def run_benchmark(config: dict) -> dict: print(f"Temperature: {temperature}") print(f"Seed: {seed}") print(f"Pipelines: direct={run_direct}, text={run_text}, latent={run_latent}, " - f"rosetta={run_rosetta}, text_cross_model={run_text_cross_model}") + f"rosetta={run_rosetta}, logit_guided={run_logit_guided}, " + f"text_cross_model={run_text_cross_model}") print() dataset = load_dataset(max_samples) @@ -134,6 +143,7 @@ def run_benchmark(config: dict) -> dict: latent_results = None text_results = None rosetta_results = None + logit_guided_results = None text_cross_model_results = None if run_direct: @@ -182,7 +192,7 @@ def run_benchmark(config: dict) -> dict: # Load model B if needed for cross-model modes model_b = tokenizer_b = connector_b = identity_b = None - if run_rosetta or run_text_cross_model: + if run_rosetta or run_logit_guided or run_text_cross_model: model_b, tokenizer_b, connector_b, identity_b = load_model(model_b_name, device) if run_text_cross_model: @@ -235,6 +245,41 @@ def run_benchmark(config: dict) -> dict: num_transfer_states=num_transfer_states, ) + if run_logit_guided: + from benchmarks.gsm8k_2agent.pipeline_logit_guided import run_logit_guided_benchmark + from avp.rosetta.calibrate import calibrate + + print("\n" + "=" * 50) + print("Running LOGIT-GUIDED (cross-model logit bias) pipeline...") + print(f" Model A (Researcher): {model_name}") + print(f" Model B (Solver): {model_b_name}") + print(f" Alpha: {logit_bias_alpha}") + print(f" Confidence threshold: {logit_bias_confidence_threshold}") + print("=" * 50) + set_seed(seed) + + # Calibrate (reuse if already done for rosetta) + if 'avp_map' not in dir() or avp_map is None: + print("Calibrating Rosetta Stone projection...") + avp_map = calibrate( + source_model=model, target_model=model_b, + source_tokenizer=tokenizer, target_tokenizer=tokenizer_b, + device=device, + ) + print(f" Method: {avp_map.method.value}, " + f"validation_score: {avp_map.validation_score:.4f}, " + f"{avp_map.source_dim}d → {avp_map.target_dim}d") + + logit_guided_results = run_logit_guided_benchmark( + conn_a=connector, model_a=model, tokenizer_a=tokenizer, + identity_a=identity, model_b=model_b, tokenizer_b=tokenizer_b, + device=device, avp_map=avp_map, dataset=dataset, + latent_steps=latent_steps, max_new_tokens=max_new_tokens, + temperature=temperature, top_p=top_p, verbose=verbose, + logit_bias_alpha=logit_bias_alpha, + logit_bias_confidence_threshold=logit_bias_confidence_threshold, + ) + # Free model B to reclaim GPU memory if model_b is not None: del model_b, tokenizer_b, connector_b, identity_b @@ -254,6 +299,8 @@ def run_benchmark(config: dict) -> dict: modes.append(("Text", 13, text_results)) if rosetta_results is not None: modes.append(("Rosetta", 13, rosetta_results)) + if logit_guided_results is not None: + modes.append(("Logit-Guided", 13, logit_guided_results)) if text_cross_model_results is not None: modes.append(("Text Cross-Model", 16, text_cross_model_results)) @@ -267,6 +314,8 @@ def run_benchmark(config: dict) -> dict: available["latent"] = latent_results if rosetta_results is not None: available["rosetta"] = rosetta_results + if logit_guided_results is not None: + available["logit_guided"] = logit_guided_results if text_cross_model_results is not None: available["text_cross_model"] = text_cross_model_results agreement_data = compute_agreement(available) if len(available) > 1 else None @@ -289,7 +338,7 @@ def run_benchmark(config: dict) -> dict: "config": { "benchmark": "gsm8k_2agent", "model_a": model_name, - "model_b": model_b_name if (run_rosetta or run_text_cross_model) else None, + "model_b": model_b_name if (run_rosetta or run_logit_guided or run_text_cross_model) else None, "device": device, "mode": mode, "max_samples": max_samples, @@ -320,6 +369,11 @@ def run_benchmark(config: dict) -> dict: "summary": compute_stats(rosetta_results), "samples": rosetta_results, } + if logit_guided_results is not None: + output_data["logit_guided"] = { + "summary": compute_stats(logit_guided_results), + "samples": logit_guided_results, + } if text_cross_model_results is not None: output_data["text_cross_model"] = { "summary": compute_stats(text_cross_model_results), diff --git a/benchmarks/shared/generation.py b/benchmarks/shared/generation.py index 541fc8e..2682c9b 100644 --- a/benchmarks/shared/generation.py +++ b/benchmarks/shared/generation.py @@ -47,6 +47,7 @@ def generate_text( max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.95, + logits_processor: Optional[Any] = None, ) -> Tuple[str, Optional[Any]]: """Generate text from input_ids, optionally with a pre-filled KV-cache. @@ -81,7 +82,7 @@ def generate_text( ) attention_mask = torch.cat([past_mask, attention_mask], dim=-1) - outputs = model.generate( + gen_kwargs = dict( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, @@ -93,6 +94,9 @@ def generate_text( past_key_values=past_key_values, cache_position=cache_position, ) + if logits_processor is not None: + gen_kwargs["logits_processor"] = logits_processor + outputs = model.generate(**gen_kwargs) generated_ids = outputs.sequences[0, prompt_len:] text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() diff --git a/src/avp/connectors/huggingface.py b/src/avp/connectors/huggingface.py index eb56470..1f78e53 100644 --- a/src/avp/connectors/huggingface.py +++ b/src/avp/connectors/huggingface.py @@ -541,10 +541,13 @@ def generate( context: Optional[AVPContext] = None, source: Optional["HuggingFaceConnector"] = None, cross_model: bool = False, + cross_model_method: str = "rosetta", max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95, do_sample: bool = True, + logit_bias_alpha: float = 0.5, + logit_bias_confidence_threshold: float = 0.8, _diagnostics: Optional[Any] = None, ) -> str: """Generate text, optionally conditioned on latent context from think(). @@ -558,10 +561,19 @@ def generate( Requires ``cross_model=True``. cross_model: Must be True to enable cross-model projection. Cross-model (Rosetta Stone) is experimental. Default False. + cross_model_method: Cross-model method to use. Options: + - ``"rosetta"`` (default): Project hidden state to target embedding + space and prime KV-cache with a single virtual token. + - ``"logit_guided"``: Map source vocabulary distribution to target + vocabulary as additive logit bias during generation. max_new_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling threshold. do_sample: Whether to use sampling (True) or greedy decoding (False). + logit_bias_alpha: Scaling factor for logit-guided bias (0.0-2.0). + Only used when cross_model_method="logit_guided". Default 0.5. + logit_bias_confidence_threshold: When target model's max softmax + probability exceeds this, skip bias for that step. Default 0.8. _diagnostics: Internal. Pre-created TransferDiagnostics to populate. Returns: @@ -573,6 +585,8 @@ def generate( """ import torch + logit_bias_processor = None + # Cross-model: auto-project when source connector is provided if ( source is not None @@ -590,6 +604,15 @@ def generate( ) context = None source = None + elif cross_model_method == "logit_guided": + logit_bias_processor = self._compute_logit_bias_processor( + source, context, + alpha=logit_bias_alpha, + confidence_threshold=logit_bias_confidence_threshold, + _diagnostics=_diagnostics, + ) + # Don't use context for KV-cache priming in logit-guided mode + context = None else: context = self._apply_rosetta_projection( source, context, _diagnostics=_diagnostics, @@ -648,6 +671,11 @@ def generate( if past_kv is not None: gen_kwargs["past_key_values"] = past_kv gen_kwargs["cache_position"] = cache_position + if logit_bias_processor is not None: + from transformers import LogitsProcessorList + gen_kwargs["logits_processor"] = LogitsProcessorList( + [logit_bias_processor] + ) outputs = self.model.generate(**gen_kwargs) generated_ids = outputs.sequences[0, prompt_len:] @@ -673,6 +701,83 @@ def generate( return text + # --- Cross-model logit-guided decoding --- + + def _compute_logit_bias_processor( + self, + source: "HuggingFaceConnector", + context: AVPContext, + alpha: float = 0.5, + confidence_threshold: float = 0.8, + _diagnostics: Optional[Any] = None, + ) -> Any: + """Compute logit bias processor for cross-model guided decoding. + + Uses the source model's last hidden state (from think()) to compute + a vocabulary distribution, maps it through vocab overlap to the target + vocabulary, and returns a LogitsProcessor that applies it as additive + bias during generation. + + Args: + source: Source model connector. + context: AVPContext from source's think() — must have last_hidden_state. + alpha: Bias scaling factor. + confidence_threshold: Skip bias when target is this confident. + _diagnostics: Internal diagnostics object. + + Returns: + CrossModelLogitBias processor instance. + """ + if context.last_hidden_state is None: + raise ValueError( + "Logit-guided decoding requires context with last_hidden_state. " + "Use think() to produce the context." + ) + + avp_map = self._get_or_calibrate_map(source) + + # Get source lm_head + source_lm_head = source.model.get_output_embeddings() + if source_lm_head is None: + source_lm_head = getattr(source.model, "lm_head", None) + if source_lm_head is None or not hasattr(source_lm_head, "weight"): + raise RealignmentError( + "Cannot get source output embeddings (lm_head) for " + "logit-guided decoding." + ) + + from ..rosetta.logit_guided import ( + CrossModelLogitBias, + compute_cross_model_logit_bias, + ) + + target_vocab_size = self.model.config.vocab_size + bias = compute_cross_model_logit_bias( + source_hidden_state=context.last_hidden_state, + source_lm_head_weight=source_lm_head.weight, + avp_map=avp_map, + target_vocab_size=target_vocab_size, + ) + + if _diagnostics is not None: + method = getattr(avp_map, "method", None) + method_name = method.name if hasattr(method, "name") else str(method or "") + _diagnostics.transfer_mode = "logit_guided" + _diagnostics.projection_method = method_name + + logger.info( + "Logit-guided decoding: alpha=%.2f, confidence_threshold=%.2f, " + "bias nonzero=%d/%d", + alpha, confidence_threshold, + int((bias != 0).sum()), target_vocab_size, + ) + + return CrossModelLogitBias( + bias=bias, + alpha=alpha, + confidence_threshold=confidence_threshold, + ) + # --- Cross-model rosetta projection --- def _apply_rosetta_projection( diff --git a/src/avp/easy.py b/src/avp/easy.py index 1ec4a2f..b4c4a54 100644 --- a/src/avp/easy.py +++ b/src/avp/easy.py @@ -252,6 +252,7 @@ def generate( model: str, source_model: Optional[str] = None, cross_model: bool = False, + cross_model_method: str = "rosetta", steps: int = 20, context: Optional["AVPContext"] = None, store: Optional[Any] = None, @@ -259,6 +260,7 @@ def generate( prior_key: Optional[str] = None, max_new_tokens: int = 512, temperature: float = 0.7, + logit_bias_alpha: float = 0.5, collect_metrics: bool = False, debug_config: Optional["DebugConfig"] = None, ) -> Union[str, Tuple[str, "GenerateMetrics"]]: @@ -372,8 +374,10 @@ def generate( context=source_context, source=source_connector, cross_model=True, + cross_model_method=cross_model_method, max_new_tokens=max_new_tokens, temperature=temperature, + logit_bias_alpha=logit_bias_alpha, _diagnostics=diagnostics, ) generate_duration = _time.perf_counter() - t_gen diff --git a/src/avp/rosetta/logit_guided.py b/src/avp/rosetta/logit_guided.py new file mode 100644 index 0000000..847aa5b --- /dev/null +++ b/src/avp/rosetta/logit_guided.py @@ -0,0 +1,134 @@ +"""Logit-guided decoding for cross-model communication. + +Instead of compressing source model information into a single virtual token +(standard rosetta), distributes the signal across the target model's entire +autoregressive generation as additive logit biases. + +The source model's vocabulary distribution (from think()) is mapped through +vocabulary overlap to the target model's vocabulary and applied as a constant +bias during generation. +""" + +from typing import Any, Optional + +from .._torch_compat import require_torch as _require_torch + + +class CrossModelLogitBias: + """HuggingFace LogitsProcessor that applies cross-model logit bias. + + During each generation step, adds a scaled bias vector to the target model's + logits. Implements confidence gating: when the target model is already highly + confident (max probability > threshold), the bias is suppressed to avoid + pushing the model away from correct predictions. + + Compatible with transformers LogitsProcessor protocol (__call__ signature). + """ + + def __init__( + self, + bias: Any, + alpha: float = 0.5, + confidence_threshold: float = 0.8, + ): + """Initialize logit bias processor. + + Args: + bias: Tensor [target_vocab_size] — additive bias for target logits. + Should be zero-mean over mapped tokens. + alpha: Scaling factor for the bias. 0.5 = conservative (recommended + for cross-vocab mapping). Higher = stronger source influence. + confidence_threshold: When target's max softmax probability exceeds + this value, the bias is suppressed for that step. Prevents + "obvious blindness" (biasing away from correct predictions). + """ + self.bias = bias + self.alpha = alpha + self.confidence_threshold = confidence_threshold + + def __call__(self, input_ids: Any, scores: Any) -> Any: + """Apply cross-model logit bias with confidence gating. + + Args: + input_ids: [batch, seq_len] — generated token IDs so far. + scores: [batch, vocab_size] — target model logits for current step. + + Returns: + Modified logits [batch, vocab_size]. + """ + torch = _require_torch() + + bias = self.bias.to(device=scores.device, dtype=scores.dtype) + + # Confidence gating: skip bias when target is already confident + with torch.no_grad(): + probs = torch.softmax(scores, dim=-1) + max_prob = probs.max(dim=-1).values # [batch] + + # Per-batch-element mask: 1.0 if uncertain, 0.0 if confident + mask = (max_prob < self.confidence_threshold).unsqueeze(-1).float() + + return scores + self.alpha * bias * mask + + +def compute_cross_model_logit_bias( + source_hidden_state: Any, + source_lm_head_weight: Any, + avp_map: Any, + target_vocab_size: int, + temperature: float = 1.0, +) -> Any: + """Compute logit bias vector for cross-model guided decoding. + + Takes the source model's last hidden state (from think()), computes its + vocabulary distribution, and maps it through vocab overlap to the target + model's vocabulary as an additive bias. + + Args: + source_hidden_state: Tensor [1, D_src] or [D_src] — last hidden state + from source model's think(). + source_lm_head_weight: Source model's lm_head weight [vocab_size_src, D_src]. + avp_map: AVPMap with src_indices, tgt_indices, and method. + target_vocab_size: Target model's vocabulary size. + temperature: Softmax temperature for computing source distribution. + Lower = sharper bias. Default 1.0. + + Returns: + Tensor [target_vocab_size] — zero-mean additive bias for target logits. + """ + torch = _require_torch() + from ..types import ProjectionMethod + + h = source_hidden_state.detach().to(torch.float32) + if h.dim() == 2: + h = h.squeeze(0) # [D_src] + + w_src = source_lm_head_weight.detach().to(device=h.device, dtype=torch.float32) + + # Compute source log-probabilities + source_logits = torch.matmul(h, w_src.T) # [vocab_size_src] + source_log_probs = torch.log_softmax(source_logits / temperature, dim=-1) + + # Initialize target bias to zero (unmapped tokens get no bias) + target_bias = torch.zeros(target_vocab_size, device=h.device, dtype=torch.float32) + + if avp_map.method == ProjectionMethod.VOCAB_OVERLAP: + src_idx = avp_map.src_indices.to(h.device) + tgt_idx = avp_map.tgt_indices.to(h.device) + target_bias[tgt_idx] = source_log_probs[src_idx] + elif avp_map.method == ProjectionMethod.VOCAB_MEDIATED: + # Same tokenizer — direct 1:1 mapping + shared_vocab = min(source_log_probs.shape[0], target_vocab_size) + target_bias[:shared_vocab] = source_log_probs[:shared_vocab] + else: + # Ridge/Procrustes — no token-level mapping available + # Fall back to no bias (caller should use rosetta instead) + return target_bias + + # Zero-mean the mapped entries so the bias doesn't shift the distribution's + # center of mass (only nudges relative token preferences) + nonzero_mask = target_bias != 0.0 + if nonzero_mask.any(): + target_bias[nonzero_mask] -= target_bias[nonzero_mask].mean() + + return target_bias diff --git a/tests/test_logit_guided.py b/tests/test_logit_guided.py new file mode 100644 index 0000000..834a327 --- /dev/null +++ b/tests/test_logit_guided.py @@ -0,0 +1,269 @@ +"""Tests for logit-guided cross-model decoding.""" + +import pytest +import torch +from conftest import requires_torch, requires_transformers + + +# --------------------------------------------------------------------------- +# Unit tests for compute_cross_model_logit_bias +# --------------------------------------------------------------------------- + +@requires_torch +class TestComputeLogitBias: + """Test bias computation from source hidden states.""" + + def _make_avp_map(self, method, src_indices=None, tgt_indices=None, + target_norm=None): + """Create a minimal AVPMap-like object for testing.""" + from avp.rosetta.calibrate import AVPMap + + return AVPMap( + source_model_id="source", + source_hash="src_hash", + source_dim=32, + target_model_id="target", + target_hash="tgt_hash", + target_dim=32, + w_map=torch.randn(32, 32), + bias=None, + target_norm=target_norm or torch.tensor(1.0), + method=method, + anchor_count=0, + validation_score=0.0, + src_indices=src_indices, + tgt_indices=tgt_indices, + ) + + def test_vocab_overlap_bias_shape(self): + """Bias tensor has correct shape matching target vocab.""" + from avp.rosetta.logit_guided import compute_cross_model_logit_bias + from avp.types import ProjectionMethod + + hidden = torch.randn(1, 32) + lm_head_w = torch.randn(100, 32) # source vocab=100 + target_vocab = 120 + + src_idx = torch.arange(50) # 50 shared tokens + tgt_idx = torch.arange(50) + + avp_map = self._make_avp_map( + ProjectionMethod.VOCAB_OVERLAP, + src_indices=src_idx, + tgt_indices=tgt_idx, + ) + + bias = compute_cross_model_logit_bias( + hidden, lm_head_w, avp_map, target_vocab, + ) + + assert bias.shape == (target_vocab,) + + def test_vocab_overlap_unmapped_tokens_zero(self): + """Unmapped tokens should have zero bias.""" + from avp.rosetta.logit_guided import compute_cross_model_logit_bias + from avp.types import ProjectionMethod + + hidden = torch.randn(1, 32) + lm_head_w = torch.randn(100, 32) + target_vocab = 120 + + src_idx = torch.arange(50) + tgt_idx = torch.arange(50) # only tokens 0-49 mapped + + avp_map = self._make_avp_map( + ProjectionMethod.VOCAB_OVERLAP, + src_indices=src_idx, + tgt_indices=tgt_idx, + ) + + bias = compute_cross_model_logit_bias( + hidden, lm_head_w, avp_map, target_vocab, + ) + + # Tokens 50-119 should be zero (before zero-mean adjustment) + # After zero-mean, mapped tokens are shifted but unmapped stay zero + unmapped_mask = torch.ones(target_vocab, dtype=torch.bool) + unmapped_mask[tgt_idx] = False + assert (bias[unmapped_mask] == 0.0).all() + + def test_vocab_overlap_bias_is_zero_mean(self): + """Mapped tokens should have zero-mean bias.""" + from avp.rosetta.logit_guided import compute_cross_model_logit_bias + from avp.types import ProjectionMethod + + hidden = torch.randn(1, 32) + lm_head_w = torch.randn(100, 32) + target_vocab = 120 + + src_idx = torch.arange(50) + tgt_idx = torch.arange(50) + + avp_map = self._make_avp_map( + ProjectionMethod.VOCAB_OVERLAP, + src_indices=src_idx, + tgt_indices=tgt_idx, + ) + + bias = compute_cross_model_logit_bias( + hidden, lm_head_w, avp_map, target_vocab, + ) + + mapped_bias = bias[tgt_idx] + assert abs(mapped_bias.mean().item()) < 1e-5 + + def test_vocab_mediated_bias(self): + """VOCAB_MEDIATED uses direct 1:1 mapping.""" + from avp.rosetta.logit_guided import compute_cross_model_logit_bias + from avp.types import ProjectionMethod + + hidden = torch.randn(1, 32) + lm_head_w = torch.randn(100, 32) + target_vocab = 100 # same size + + avp_map = self._make_avp_map(ProjectionMethod.VOCAB_MEDIATED) + + bias = compute_cross_model_logit_bias( + hidden, lm_head_w, avp_map, target_vocab, + ) + + assert bias.shape == (target_vocab,) + # All tokens should be mapped (non-zero after zero-mean) + assert (bias != 0.0).sum() > 0 + + def test_ridge_returns_zero_bias(self): + """RIDGE method has no token-level mapping — returns zero bias.""" + from avp.rosetta.logit_guided import compute_cross_model_logit_bias + from avp.types import ProjectionMethod + + hidden = torch.randn(1, 32) + lm_head_w = torch.randn(100, 32) + target_vocab = 120 + + avp_map = self._make_avp_map(ProjectionMethod.RIDGE) + + bias = compute_cross_model_logit_bias( + hidden, lm_head_w, avp_map, target_vocab, + ) + + assert (bias == 0.0).all() + + def test_1d_hidden_state(self): + """Should handle 1D hidden state (no batch dim).""" + from avp.rosetta.logit_guided import compute_cross_model_logit_bias + from avp.types import ProjectionMethod + + hidden = torch.randn(32) # no batch dim + lm_head_w = torch.randn(100, 32) + target_vocab = 120 + + src_idx = torch.arange(50) + tgt_idx = torch.arange(50) + + avp_map = self._make_avp_map( + ProjectionMethod.VOCAB_OVERLAP, + src_indices=src_idx, + tgt_indices=tgt_idx, + ) + + bias = compute_cross_model_logit_bias( + hidden, lm_head_w, avp_map, target_vocab, + ) + + assert bias.shape == (target_vocab,) + + +# --------------------------------------------------------------------------- +# Unit tests for CrossModelLogitBias processor +# --------------------------------------------------------------------------- + +@requires_torch +class TestCrossModelLogitBias: + """Test the LogitsProcessor behavior.""" + + def test_applies_bias(self): + """Bias should modify scores.""" + from avp.rosetta.logit_guided import CrossModelLogitBias + + bias = torch.tensor([0.0, 1.0, -1.0, 0.5]) + processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=1.0) + + input_ids = torch.tensor([[1, 2, 3]]) + scores = torch.tensor([[0.0, 0.0, 0.0, 0.0]]) + + result = processor(input_ids, scores) + + # With uniform scores and confidence_threshold=1.0 (always apply), + # result should be scores + alpha * bias + expected = scores + bias + assert torch.allclose(result, expected, atol=1e-5) + + def test_confidence_gating_suppresses_bias(self): + """When target is confident, bias should be suppressed.""" + from avp.rosetta.logit_guided import CrossModelLogitBias + + bias = torch.tensor([0.0, 10.0, -10.0, 5.0]) + processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=0.5) + + input_ids = torch.tensor([[1]]) + # Very confident scores — token 0 dominates + scores = torch.tensor([[100.0, -100.0, -100.0, -100.0]]) + + result = processor(input_ids, scores) + + # Max prob ≈ 1.0 >> threshold 0.5, so bias should be suppressed + assert torch.allclose(result, scores, atol=1e-5) + + def test_confidence_gating_allows_bias_when_uncertain(self): + """When target is uncertain, bias should be applied.""" + from avp.rosetta.logit_guided import CrossModelLogitBias + + bias = torch.tensor([0.0, 1.0, -1.0, 0.5]) + processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=0.9) + + input_ids = torch.tensor([[1]]) + # Uniform scores — max_prob = 0.25 < 0.9 + scores = torch.zeros(1, 4) + + result = processor(input_ids, scores) + + # Should apply bias since target is uncertain + expected = scores + bias + assert torch.allclose(result, expected, atol=1e-5) + + def test_alpha_scaling(self): + """Alpha should scale the bias.""" + from avp.rosetta.logit_guided import CrossModelLogitBias + + bias = torch.tensor([1.0, 2.0, 3.0]) + processor = CrossModelLogitBias(bias, alpha=0.5, confidence_threshold=1.0) + + input_ids = torch.tensor([[1]]) + scores = torch.zeros(1, 3) + + result = processor(input_ids, scores) + expected = torch.tensor([[0.5, 1.0, 1.5]]) + assert torch.allclose(result, expected, atol=1e-5) + + def test_batch_confidence_gating(self): + """Per-batch-element gating should work.""" + from avp.rosetta.logit_guided import CrossModelLogitBias + + bias = torch.tensor([0.0, 1.0, -1.0]) + processor = CrossModelLogitBias(bias, alpha=1.0, confidence_threshold=0.5) + + input_ids = torch.tensor([[1], [2]]) + # Batch elem 0: confident (token 0 dominates) + # Batch elem 1: uncertain (uniform) + scores = torch.tensor([ + [100.0, -100.0, -100.0], + [0.0, 0.0, 0.0], + ]) + + result = processor(input_ids, scores) + + # Elem 0: bias suppressed + assert torch.allclose(result[0], scores[0], atol=1e-5) + # Elem 1: bias applied + expected_1 = scores[1] + bias + assert torch.allclose(result[1], expected_1, atol=1e-5) From b06a957286c23acb0dda812b5e1ebf13822d8257 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 01:28:32 +0000 Subject: [PATCH 02/14] Add smart routing and mid-layer injection for cross-model transfer Smart routing: Enhanced quality gate with task-type classification (math/code vs comprehension) using lexical features. Zero latency overhead. Backward compatible with existing assess_transfer() API. Mid-layer injection: Inject projected hidden states at ~75% depth via forward hook instead of layer-0 KV-cache priming. Based on Ramesh & Li (2501.14082) cross-model injection research. Both features available as cross_model_method options in HuggingFaceConnector.generate() and as benchmark pipeline modes. 47 new tests (31 smart routing + 16 mid-layer), all passing. Co-Authored-By: Claude Opus 4.6 --- benchmarks/gsm8k_2agent/pipeline_mid_layer.py | 243 ++++++++++++++++ benchmarks/gsm8k_2agent/run_gsm8k_2agent.py | 53 +++- src/avp/__init__.py | 4 + src/avp/connectors/huggingface.py | 101 ++++++- src/avp/rosetta/__init__.py | 4 +- src/avp/rosetta/mid_layer.py | 210 ++++++++++++++ src/avp/rosetta/quality.py | 264 ++++++++++++++++-- tests/test_mid_layer.py | 200 +++++++++++++ tests/test_smart_routing.py | 248 ++++++++++++++++ 9 files changed, 1300 insertions(+), 27 deletions(-) create mode 100644 benchmarks/gsm8k_2agent/pipeline_mid_layer.py create mode 100644 src/avp/rosetta/mid_layer.py create mode 100644 tests/test_mid_layer.py create mode 100644 tests/test_smart_routing.py diff --git a/benchmarks/gsm8k_2agent/pipeline_mid_layer.py b/benchmarks/gsm8k_2agent/pipeline_mid_layer.py new file mode 100644 index 0000000..77ea0e1 --- /dev/null +++ b/benchmarks/gsm8k_2agent/pipeline_mid_layer.py @@ -0,0 +1,243 @@ +"""Mid-layer injection pipeline: 2-agent chain with intermediate layer injection. + +Researcher runs on model A (latent steps -> extract hidden state -> project). +Solver runs on model B (inject at layer ~75% depth via forward hook -> generate). + +Unlike rosetta (injects projected embedding at layer 0 via inputs_embeds), +mid-layer injects at an intermediate layer, operating directly in the +semantic representation space. +""" + +import time +from typing import Any, Dict, List + +import torch + +from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt +from benchmarks.shared.kv_utils import get_past_length +from benchmarks.shared.metrics import gpu_memory_tracker +from .agents import AGENTS, build_latent_prompt +from .evaluate import extract_gold, extract_gsm8k_answer, check_correct + + +def run_mid_layer_pipeline( + conn_a: Any, + model_a: Any, + tokenizer_a: Any, + identity_a: Any, + model_b: Any, + tokenizer_b: Any, + device: str, + avp_map: Any, + question: str, + gold_solution: str, + latent_steps: int = 10, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.95, + verbose: bool = False, + depth_ratio: float = 0.75, +) -> Dict: + """Run the 2-agent cross-model pipeline with mid-layer injection. + + Researcher (model A): latent steps -> extract hidden state -> project + Solver (model B): inject at intermediate layer via forward hook -> generate + """ + from avp.rosetta.mid_layer import ( + compute_injection_layer, + mid_layer_injection_hook, + ) + + with gpu_memory_tracker(device) as mem: + t0 = time.perf_counter() + agent_traces: List[Dict] = [] + total_prompt_tokens = 0 + total_latent_steps = 0 + total_output_tokens = 0 + + researcher = AGENTS[0] + solver = AGENTS[1] + + # --- Agent 1: Researcher on model A (latent steps) --- + messages = build_latent_prompt(researcher.role, question) + prompt_text = render_prompt(tokenizer_a, messages) + input_ids, attention_mask = tokenize_prompt(tokenizer_a, prompt_text, device) + + agent_t0 = time.perf_counter() + prompt_tokens = int(input_ids.shape[-1]) + total_prompt_tokens += prompt_tokens + total_latent_steps += latent_steps + + # Collect hidden states from all latent steps + past_kv, hidden_states = conn_a.generate_latent_steps( + input_ids, latent_steps=latent_steps, attention_mask=attention_mask, + collect_hidden_states=True, + ) + + # Use last hidden state for projection + last_hidden = hidden_states[-1].unsqueeze(0) # [1, D_src] + + # Project to target model space + proj_t0 = time.perf_counter() + projected, proj_metrics = conn_a.project_hidden_for_cross_model( + last_hidden, avp_map, return_metrics=True, + ) + projection_ms = (time.perf_counter() - proj_t0) * 1000 + + # Ensure [1, D] shape for injection + if projected.dim() == 1: + projected = projected.unsqueeze(0) + if projected.dim() == 3: + projected = projected.squeeze(0)[-1:, :] + + wire_bytes = projected.nelement() * projected.element_size() + agent_time_ms = (time.perf_counter() - agent_t0) * 1000 + + # Compute injection layer + target_num_layers = model_b.config.num_hidden_layers + injection_layer = compute_injection_layer(target_num_layers, depth_ratio) + + agent_traces.append({ + "name": researcher.name, + "role": researcher.role, + "prompt_tokens": prompt_tokens, + "latent_steps": latent_steps, + "projection_ms": projection_ms, + "wire_bytes": wire_bytes, + "agent_time_ms": agent_time_ms, + "injection_layer": injection_layer, + "target_num_layers": target_num_layers, + "output": "", + }) + + if verbose: + print(f" [{researcher.name}] latent steps={latent_steps}, " + f"projection={projection_ms:.1f}ms, " + f"inject at layer {injection_layer}/{target_num_layers} " + f"({100*injection_layer/target_num_layers:.0f}% depth)") + + # Free model A KV-cache + del past_kv, hidden_states + if device == "cuda": + torch.cuda.empty_cache() + + # --- Agent 2: Solver on model B (mid-layer injection + generate) --- + messages = build_latent_prompt(solver.role, question) + prompt_text = render_prompt(tokenizer_b, messages) + input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device) + + agent_t0 = time.perf_counter() + prompt_tokens = int(input_ids.shape[-1]) + total_prompt_tokens += prompt_tokens + + # Generate with mid-layer injection hook + inject_hidden = projected.to(device).to(model_b.dtype) + + with mid_layer_injection_hook(model_b, injection_layer, inject_hidden): + text, _ = generate_text( + model_b, tokenizer_b, input_ids, attention_mask, device, + past_key_values=None, # No KV-cache priming + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + ) + + output_encoded = tokenizer_b(text, add_special_tokens=False) + output_tokens = len(output_encoded["input_ids"]) + total_output_tokens += output_tokens + agent_time_ms = (time.perf_counter() - agent_t0) * 1000 + + agent_traces.append({ + "name": solver.name, + "role": solver.role, + "prompt_tokens": prompt_tokens, + "output_tokens": output_tokens, + "agent_time_ms": agent_time_ms, + "output": text, + }) + + if verbose: + print(f" [{solver.name}] output ({len(text)} chars): {text[:200]}...") + + wall_time = time.perf_counter() - t0 + + total_tokens = total_prompt_tokens + total_latent_steps + total_output_tokens + tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0 + + gold = extract_gold(gold_solution) + prediction = extract_gsm8k_answer(agent_traces[-1]["output"]) + correct = check_correct(prediction, gold) + + return { + "question": question, + "gold": gold, + "prediction": prediction, + "raw_output": agent_traces[-1]["output"], + "correct": correct, + "wall_time": wall_time, + "total_prompt_tokens": total_prompt_tokens, + "total_latent_steps": total_latent_steps, + "total_output_tokens": total_output_tokens, + "total_tokens": total_tokens, + "tokens_per_sec": tokens_per_sec, + "peak_memory_mb": mem["peak_memory_mb"], + "projection_overhead_ms": projection_ms, + "projection_wire_bytes": wire_bytes, + "injection_layer": injection_layer, + "depth_ratio": depth_ratio, + "hidden_state_norm": float(proj_metrics["hidden_state_norm"].mean()) if "hidden_state_norm" in proj_metrics else None, + "nearest_cos_sim": float(proj_metrics["nearest_cos_sim"].mean()) if "nearest_cos_sim" in proj_metrics else None, + "agents": agent_traces, + "mode": "mid_layer", + } + + +def run_mid_layer_benchmark( + conn_a: Any, + model_a: Any, + tokenizer_a: Any, + identity_a: Any, + model_b: Any, + tokenizer_b: Any, + device: str, + avp_map: Any, + dataset: List[Dict], + latent_steps: int = 10, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.95, + verbose: bool = False, + depth_ratio: float = 0.75, +) -> List[Dict]: + """Run mid-layer pipeline on a list of GSM8K samples.""" + results = [] + for i, sample in enumerate(dataset): + if verbose: + print(f"\n[Mid-Layer] Sample {i + 1}/{len(dataset)}: " + f"{sample['question'][:80]}...") + + result = run_mid_layer_pipeline( + conn_a, model_a, tokenizer_a, identity_a, + model_b, tokenizer_b, device, avp_map, + question=sample["question"], + gold_solution=sample["answer"], + latent_steps=latent_steps, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + verbose=verbose, + depth_ratio=depth_ratio, + ) + results.append(result) + + if verbose: + status = "CORRECT" if result["correct"] else "WRONG" + print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, " + f"time={result['wall_time']:.1f}s)") + else: + correct = sum(1 for r in results if r["correct"]) + print(f" [Mid-Layer d={depth_ratio}] {i + 1}/{len(dataset)} " + f"({correct}/{i + 1} correct, {result['wall_time']:.1f}s)", + flush=True) + + return results diff --git a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py index 6857aa3..4f93d31 100644 --- a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py +++ b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py @@ -40,7 +40,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--mode", choices=["latent", "text", "direct", "rosetta", "logit_guided", - "text_cross_model", "both", "all"], + "mid_layer", "text_cross_model", "both", "all"], default="all", help="Pipeline(s) to run (default: all)", ) @@ -119,12 +119,13 @@ def run_benchmark(config: dict) -> dict: run_text = mode in ("text", "both", "all") run_rosetta = mode in ("rosetta", "all") run_logit_guided = mode in ("logit_guided", "all") + run_mid_layer = mode in ("mid_layer", "all") run_text_cross_model = mode in ("text_cross_model", "all") print(f"Device: {device}") print(f"Mode: {mode}") print(f"Model A: {model_name}") - if run_rosetta or run_logit_guided or run_text_cross_model: + if run_rosetta or run_logit_guided or run_mid_layer or run_text_cross_model: print(f"Model B: {model_b_name}") print(f"Samples: {max_samples}") print(f"Latent steps: {latent_steps}") @@ -133,7 +134,7 @@ def run_benchmark(config: dict) -> dict: print(f"Seed: {seed}") print(f"Pipelines: direct={run_direct}, text={run_text}, latent={run_latent}, " f"rosetta={run_rosetta}, logit_guided={run_logit_guided}, " - f"text_cross_model={run_text_cross_model}") + f"mid_layer={run_mid_layer}, text_cross_model={run_text_cross_model}") print() dataset = load_dataset(max_samples) @@ -144,6 +145,7 @@ def run_benchmark(config: dict) -> dict: text_results = None rosetta_results = None logit_guided_results = None + mid_layer_results = None text_cross_model_results = None if run_direct: @@ -192,7 +194,7 @@ def run_benchmark(config: dict) -> dict: # Load model B if needed for cross-model modes model_b = tokenizer_b = connector_b = identity_b = None - if run_rosetta or run_logit_guided or run_text_cross_model: + if run_rosetta or run_logit_guided or run_mid_layer or run_text_cross_model: model_b, tokenizer_b, connector_b, identity_b = load_model(model_b_name, device) if run_text_cross_model: @@ -280,6 +282,38 @@ def run_benchmark(config: dict) -> dict: logit_bias_confidence_threshold=logit_bias_confidence_threshold, ) + if run_mid_layer: + from benchmarks.gsm8k_2agent.pipeline_mid_layer import run_mid_layer_benchmark + from avp.rosetta.calibrate import calibrate + + print("\n" + "=" * 50) + print("Running MID-LAYER (cross-model mid-layer injection) pipeline...") + print(f" Model A (Researcher): {model_name}") + print(f" Model B (Solver): {model_b_name}") + print(f" Depth ratio: 0.75") + print("=" * 50) + set_seed(seed) + + # Calibrate (reuse if already done) + if 'avp_map' not in dir() or avp_map is None: + print("Calibrating Rosetta Stone projection...") + avp_map = calibrate( + source_model=model, target_model=model_b, + source_tokenizer=tokenizer, target_tokenizer=tokenizer_b, + device=device, + ) + print(f" Method: {avp_map.method.value}, " + f"validation_score: {avp_map.validation_score:.4f}, " + f"{avp_map.source_dim}d -> {avp_map.target_dim}d") + + mid_layer_results = run_mid_layer_benchmark( + conn_a=connector, model_a=model, tokenizer_a=tokenizer, + identity_a=identity, model_b=model_b, tokenizer_b=tokenizer_b, + device=device, avp_map=avp_map, dataset=dataset, + latent_steps=latent_steps, max_new_tokens=max_new_tokens, + temperature=temperature, top_p=top_p, verbose=verbose, + ) + # Free model B to reclaim GPU memory if model_b is not None: del model_b, tokenizer_b, connector_b, identity_b @@ -301,6 +335,8 @@ def run_benchmark(config: dict) -> dict: modes.append(("Rosetta", 13, rosetta_results)) if logit_guided_results is not None: modes.append(("Logit-Guided", 13, logit_guided_results)) + if mid_layer_results is not None: + modes.append(("Mid-Layer", 13, mid_layer_results)) if text_cross_model_results is not None: modes.append(("Text Cross-Model", 16, text_cross_model_results)) @@ -316,6 +352,8 @@ def run_benchmark(config: dict) -> dict: available["rosetta"] = rosetta_results if logit_guided_results is not None: available["logit_guided"] = logit_guided_results + if mid_layer_results is not None: + available["mid_layer"] = mid_layer_results if text_cross_model_results is not None: available["text_cross_model"] = text_cross_model_results agreement_data = compute_agreement(available) if len(available) > 1 else None @@ -338,7 +376,7 @@ def run_benchmark(config: dict) -> dict: "config": { "benchmark": "gsm8k_2agent", "model_a": model_name, - "model_b": model_b_name if (run_rosetta or run_logit_guided or run_text_cross_model) else None, + "model_b": model_b_name if (run_rosetta or run_logit_guided or run_mid_layer or run_text_cross_model) else None, "device": device, "mode": mode, "max_samples": max_samples, @@ -374,6 +412,11 @@ def run_benchmark(config: dict) -> dict: "summary": compute_stats(logit_guided_results), "samples": logit_guided_results, } + if mid_layer_results is not None: + output_data["mid_layer"] = { + "summary": compute_stats(mid_layer_results), + "samples": mid_layer_results, + } if text_cross_model_results is not None: output_data["text_cross_model"] = { "summary": compute_stats(text_cross_model_results), diff --git a/src/avp/__init__.py b/src/avp/__init__.py index 54dc0f8..06452bf 100644 --- a/src/avp/__init__.py +++ b/src/avp/__init__.py @@ -84,7 +84,9 @@ "validate_projection", "TransferQualityConfig", "TransferQualityResult", + "TaskClassification", "assess_transfer", + "classify_task", } # Transport classes are lazy-loaded because httpx is an optional dependency. @@ -178,7 +180,9 @@ def __getattr__(name: str): "validate_projection", "TransferQualityConfig", "TransferQualityResult", + "TaskClassification", "assess_transfer", + "classify_task", # Errors "AVPError", "IncompatibleModelsError", diff --git a/src/avp/connectors/huggingface.py b/src/avp/connectors/huggingface.py index 1f78e53..7154521 100644 --- a/src/avp/connectors/huggingface.py +++ b/src/avp/connectors/huggingface.py @@ -586,6 +586,7 @@ def generate( import torch logit_bias_processor = None + mid_layer_hook_ctx = None # Cross-model: auto-project when source connector is provided if ( @@ -613,6 +614,12 @@ def generate( ) # Don't use context for KV-cache priming in logit-guided mode context = None + elif cross_model_method == "mid_layer": + mid_layer_hook_ctx = self._prepare_mid_layer_injection( + source, context, _diagnostics=_diagnostics, + ) + # Don't use KV-cache priming — injection happens via hook + context = None else: context = self._apply_rosetta_projection( source, context, _diagnostics=_diagnostics, @@ -676,7 +683,11 @@ def generate( gen_kwargs["logits_processor"] = LogitsProcessorList( [logit_bias_processor] ) - outputs = self.model.generate(**gen_kwargs) + if mid_layer_hook_ctx is not None: + with mid_layer_hook_ctx: + outputs = self.model.generate(**gen_kwargs) + else: + outputs = self.model.generate(**gen_kwargs) generated_ids = outputs.sequences[0, prompt_len:] text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip() @@ -778,6 +789,94 @@ def _compute_logit_bias_processor( confidence_threshold=confidence_threshold, ) + # --- Cross-model mid-layer injection --- + + def _prepare_mid_layer_injection( + self, + source: "HuggingFaceConnector", + context: AVPContext, + _diagnostics: Optional[Any] = None, + ) -> Any: + """Prepare mid-layer injection hook for cross-model generation. + + Projects the source hidden state to target space (same as rosetta) + but installs a forward hook to inject at an intermediate layer + (~75% depth) instead of priming KV-cache at layer 0. + + Args: + source: Source model connector. + context: AVPContext from source's think(). + _diagnostics: Internal diagnostics object. + + Returns: + Context manager (mid_layer_injection_hook) ready to wrap generate(). + """ + if context.last_hidden_state is None: + raise ValueError( + "Mid-layer injection requires context with last_hidden_state. " + "Use think() to produce the context." + ) + + avp_map = self._get_or_calibrate_map(source) + + # Project source hidden state to target space + use_metrics = _diagnostics is not None + result = source.project_hidden_for_cross_model( + context.last_hidden_state, avp_map, + return_metrics=use_metrics, + ) + if use_metrics and isinstance(result, tuple): + projected, proj_metrics = result + if _diagnostics is not None and proj_metrics: + _diagnostics.hidden_state_norm = ( + float(proj_metrics["hidden_state_norm"].mean()) + if "hidden_state_norm" in proj_metrics else None + ) + _diagnostics.nearest_cos_sim = ( + float(proj_metrics["nearest_cos_sim"].mean()) + if "nearest_cos_sim" in proj_metrics else None + ) + else: + projected = result + + # Ensure [1, D] shape + if projected.dim() == 1: + projected = projected.unsqueeze(0) + if projected.dim() == 3: + projected = projected.squeeze(0) # [1, N, D] -> [N, D], take last + projected = projected[-1:, :] + + from ..rosetta.mid_layer import ( + compute_injection_layer, + mid_layer_injection_hook, + ) + + # Compute injection point + target_num_layers = self._identity.num_layers + injection_layer = compute_injection_layer(target_num_layers) + + method = getattr(avp_map, "method", None) + method_name = method.name if hasattr(method, "name") else str(method or "") + + if _diagnostics is not None: + _diagnostics.transfer_mode = "mid_layer" + _diagnostics.projection_method = method_name + + logger.info( + "Mid-layer injection: target layer %d/%d (%.0f%% depth), " + "projected shape %s", + injection_layer, target_num_layers, + 100 * injection_layer / target_num_layers, + tuple(projected.shape), + ) + + # Return the context manager — caller wraps model.generate() with it + return mid_layer_injection_hook( + model=self.model, + injection_layer=injection_layer, + projected_hidden=projected, + ) + # --- Cross-model rosetta projection --- def _apply_rosetta_projection( diff --git a/src/avp/rosetta/__init__.py b/src/avp/rosetta/__init__.py index 6cb3c9d..417ea14 100644 --- a/src/avp/rosetta/__init__.py +++ b/src/avp/rosetta/__init__.py @@ -6,7 +6,7 @@ from .calibrate import AVPMap, DEFAULT_ANCHORS, calibrate from .project import apply_cross_model_projection, vocab_overlap_projection, vocabulary_mediated_projection -from .quality import TransferQualityConfig, TransferQualityResult, assess_transfer +from .quality import TaskClassification, TransferQualityConfig, TransferQualityResult, assess_transfer, classify_task from .registry import find_map, load_map, map_id, save_map from .validate import ValidationConfig, ValidationResult, validate_projection @@ -22,7 +22,9 @@ # Quality gate "TransferQualityConfig", "TransferQualityResult", + "TaskClassification", "assess_transfer", + "classify_task", # Registry "save_map", "load_map", diff --git a/src/avp/rosetta/mid_layer.py b/src/avp/rosetta/mid_layer.py new file mode 100644 index 0000000..876ce19 --- /dev/null +++ b/src/avp/rosetta/mid_layer.py @@ -0,0 +1,210 @@ +"""Mid-layer injection for cross-model latent transfer. + +Instead of injecting projected hidden states at layer 0 (via inputs_embeds), +injects at an intermediate layer (~75% depth) using a forward hook. This +bypasses the early embedding/position-encoding layers and operates directly +in the semantic representation space. + +Based on: +- Ramesh & Li (2501.14082): Cross-model hidden state injection at intermediate + layers, up to 27% improvement over text, cross-family confirmed. +- Proportional depth mapping (2504.08775): Layer L_a/N_a maps to L_b/N_b + across architectures (p < 0.005 for 24 LLMs from 1B-70B). + +Key design decisions: +- REPLACE, not sum/mean (Ramesh & Li found sum/mean produce OOD norms) +- Proportional depth mapping: injection_layer = int(N_tgt * extraction_ratio) +- Forward hook scoped to prefill only (removed after first forward pass) +""" + +import logging +from contextlib import contextmanager +from typing import Any, Optional, Tuple + +logger = logging.getLogger(__name__) + +# Default extraction/injection depth ratio (validated at 0.75 = ~75% depth) +DEFAULT_DEPTH_RATIO = 0.75 + + +def compute_extraction_layer(num_layers: int, depth_ratio: float = DEFAULT_DEPTH_RATIO) -> int: + """Compute the layer index to extract hidden states from. + + Args: + num_layers: Total number of transformer layers in the model. + depth_ratio: Fraction of depth to extract from (0.0=first, 1.0=last). + + Returns: + Layer index (0-indexed). + """ + layer = int(num_layers * depth_ratio) + return min(layer, num_layers - 1) + + +def compute_injection_layer(num_layers: int, depth_ratio: float = DEFAULT_DEPTH_RATIO) -> int: + """Compute the layer index to inject hidden states into. + + Uses proportional depth mapping: if source extracted from 75% depth, + inject at 75% depth of target model (even if different number of layers). + + Args: + num_layers: Total number of transformer layers in target model. + depth_ratio: Fraction of depth to inject at. + + Returns: + Layer index (0-indexed). + """ + layer = int(num_layers * depth_ratio) + return min(layer, num_layers - 1) + + +def extract_mid_layer_hidden( + model_outputs: Any, + extraction_layer: int, +) -> Any: + """Extract hidden state from an intermediate layer of model outputs. + + Args: + model_outputs: Model output dict with hidden_states. + extraction_layer: Layer index to extract from. + + Returns: + Hidden state tensor [B, D] from the specified layer's last token. + """ + # hidden_states is a tuple of (num_layers + 1) tensors, each [B, seq, D] + # Index 0 = embedding output, index i = output of layer i + hidden_states = model_outputs.hidden_states + if extraction_layer + 1 >= len(hidden_states): + extraction_layer = len(hidden_states) - 2 # -1 is last layer output + # +1 because index 0 is embedding layer output, index 1 is layer 0 output + layer_hidden = hidden_states[extraction_layer + 1] + return layer_hidden[:, -1, :] # [B, D] — last token only + + +def _get_decoder_layers(model: Any): + """Get the list of decoder layers from a HuggingFace model. + + Handles different model architectures (Llama, Qwen, GPT-2, etc.). + """ + # Try common attribute paths + inner = getattr(model, "model", None) + if inner is not None: + layers = getattr(inner, "layers", None) + if layers is not None: + return layers + + # GPT-2 style + transformer = getattr(model, "transformer", None) + if transformer is not None: + h = getattr(transformer, "h", None) + if h is not None: + return h + + raise AttributeError( + f"Cannot find decoder layers in model {type(model).__name__}. " + "Expected model.model.layers or model.transformer.h" + ) + + +@contextmanager +def mid_layer_injection_hook( + model: Any, + injection_layer: int, + projected_hidden: Any, +): + """Context manager that installs a forward hook to replace hidden states + at a specific layer during the first forward pass (prefill). + + The hook fires once and then removes itself, so it only affects the + initial prefill pass, not subsequent autoregressive generation steps. + + Args: + model: HuggingFace model to hook into. + injection_layer: Layer index to inject at. + projected_hidden: Tensor [1, D] or [B, D] to replace the last token's + hidden state with. + + Yields: + None. The hook is active during the context. + """ + import torch + + layers = _get_decoder_layers(model) + target_layer = layers[injection_layer] + + fired = [False] # mutable flag for closure + + def hook_fn(module, input, output): + if fired[0]: + return output + + fired[0] = True + + # Decoder layer output is a tuple: (hidden_states, ...) or just hidden_states + if isinstance(output, tuple): + hidden = output[0] # [B, seq_len, D] + else: + hidden = output + + # Replace last token's hidden state with projected source hidden state + injection = projected_hidden.to(device=hidden.device, dtype=hidden.dtype) + if injection.dim() == 1: + injection = injection.unsqueeze(0) # [D] -> [1, D] + + # Clone to avoid in-place modification + modified = hidden.clone() + modified[:, -1, :] = injection # Replace last position + + if isinstance(output, tuple): + return (modified,) + output[1:] + return modified + + handle = target_layer.register_forward_hook(hook_fn) + try: + yield + finally: + handle.remove() + + +def project_for_mid_layer( + source_hidden: Any, + avp_map: Any, + source_model: Any, + target_model: Any, + target_num_layers: int, + injection_depth_ratio: float = DEFAULT_DEPTH_RATIO, +) -> Tuple[Any, int]: + """Project source hidden state for mid-layer injection. + + Unlike rosetta (which projects to target embedding space for layer-0 inputs_embeds), + mid-layer projects to the target's intermediate representation space. Since we don't + have a direct map between intermediate spaces, we use the same vocab-mediated/overlap + projection but normalize to the target layer's activation norm instead of the + embedding norm. + + Args: + source_hidden: Source hidden state [1, D_src] or [D_src]. + avp_map: AVPMap with projection data. + source_model: Source HuggingFace model. + target_model: Target HuggingFace model. + target_num_layers: Number of layers in target model. + injection_depth_ratio: Depth ratio for injection point. + + Returns: + Tuple of (projected_hidden [1, D_tgt], injection_layer_index). + """ + import torch + from .project import apply_cross_model_projection + + injection_layer = compute_injection_layer(target_num_layers, injection_depth_ratio) + + # Use standard vocab-mediated/overlap projection + projected = apply_cross_model_projection( + source_hidden, avp_map, source_model, target_model, + ) + + # Ensure correct shape [1, D] + if projected.dim() == 1: + projected = projected.unsqueeze(0) + + return projected, injection_layer diff --git a/src/avp/rosetta/quality.py b/src/avp/rosetta/quality.py index 2ab6f0e..d4ab40f 100644 --- a/src/avp/rosetta/quality.py +++ b/src/avp/rosetta/quality.py @@ -5,12 +5,17 @@ Primary signal: source prompt token count. Single-embedding rosetta works for short structured prompts (<300 tokens) but degrades significantly for -longer prompts. Validated across 4 benchmarks × 3 rosetta configurations: - - GSM8K cross-family: 65% at <300 tokens → 41% at 300-500 tokens - - HumanEval same-family: 61% at <300 → 40% at 300-500 → 19% at 500+ - - Even reverse rosetta (strong config): 87% → 84% → 55% at 500+ +longer prompts. Validated across 4 benchmarks x 3 rosetta configurations: + - GSM8K cross-family: 65% at <300 tokens, 41% at 300-500 tokens + - HumanEval same-family: 61% at <300, 40% at 300-500, 19% at 500+ + - Even reverse rosetta (strong config): 87%, 84%, 55% at 500+ Prompt length is a zero-cost proxy for information density. +Enhanced signal (v2): task-type classification via prompt features. Detects +math/code markers (structured tasks where rosetta works) vs comprehension +patterns (multi-paragraph context + question where text is needed). Zero +latency overhead, zero dependencies. See classify_task(). + Secondary signal (opt-in): effective rank ratio of hidden states. High effective rank means information is spread across many dimensions, which a single embedding cannot capture. Off by default; included for future @@ -20,7 +25,12 @@ from avp.rosetta.quality import assess_transfer + # Token-count only (backward compatible) result = assess_transfer(prompt_tokens=len(input_ids[0])) + + # Enhanced: with prompt text for task-type classification + result = assess_transfer(prompt_tokens=len(input_ids[0]), prompt_text=prompt) + if result.recommend_latent: # proceed with rosetta projection ... @@ -29,10 +39,47 @@ ... """ -from dataclasses import dataclass +import re +from dataclasses import dataclass, field from typing import Any, Optional +# Compiled patterns for task classification (module-level for zero overhead) +_MATH_MARKERS = re.compile( + r'\\boxed|\\frac|\\sqrt|\\sum|\\int|\\cdot|' + r'\d+\s*[+\-*/=]\s*\d+|' + r'\$[^$]+\$' +) +_CODE_MARKERS = re.compile( + r'\bdef\s+\w+|' + r'\bclass\s+\w+|' + r'\bimport\s+\w+|' + r'\bfunction\s+\w+|' + r'\breturn\s+|' + r'```|' + r'>>>|' + r'\bfor\s+\w+\s+in\s+' +) +_COMPREHENSION_STARTERS = re.compile( + r'(?:^|[\n.?!]\s*)(?:who\s+(?:is|was|were|are|did)|' + r'what\s+(?:is|was|were|are|did|does)|' + r'when\s+(?:did|was|were|is)|' + r'where\s+(?:did|was|were|is)|' + r'why\s+(?:did|was|were|is)|' + r'how\s+did|' + r'according\s+to|based\s+on|which\s+of\s+the\s+following|' + r'in\s+the\s+(?:passage|text|article|context))', + re.IGNORECASE, +) +_STRUCTURED_STARTERS = re.compile( + r'(?:^|[\n.]\s*)(?:solve|compute|calculate|find\s+the|evaluate|simplify|' + r'how\s+much|how\s+many|' + r'implement|write\s+(?:a\s+)?(?:function|program|code|class)|' + r'debug|fix\s+(?:the|this)|refactor)', + re.IGNORECASE, +) + + @dataclass class TransferQualityConfig: """Configuration for the per-transfer quality gate. @@ -42,6 +89,9 @@ class TransferQualityConfig: transfer is recommended. Default 300, validated across GSM8K, HumanEval, and HotpotQA rosetta benchmarks. Above 300 tokens, accuracy drops 20-30pp across all configurations. + use_task_classification: Whether to use prompt text features + for task-type classification. Requires prompt_text to be + passed to assess_transfer(). On by default. check_effective_rank: Whether to compute effective rank ratio of hidden states as a secondary signal. Requires torch and a hidden_states tensor. Off by default. @@ -51,10 +101,26 @@ class TransferQualityConfig: """ max_prompt_tokens: int = 300 + use_task_classification: bool = True check_effective_rank: bool = False max_effective_rank_ratio: float = 0.8 +@dataclass +class TaskClassification: + """Result of prompt task-type classification. + + Attributes: + task_type: 'structured' or 'comprehension'. + score: Numeric score (positive=structured, negative=comprehension). + features: Dict of individual feature contributions. + """ + + task_type: str + score: int + features: dict = field(default_factory=dict) + + @dataclass class TransferQualityResult: """Result of a per-transfer quality assessment. @@ -66,12 +132,112 @@ class TransferQualityResult: reason: Human-readable explanation of the recommendation. effective_rank_ratio: Effective rank ratio of hidden states, or None if not computed. + task_classification: Task-type classification result, or None + if prompt_text was not provided. """ recommend_latent: bool prompt_tokens: int reason: str effective_rank_ratio: Optional[float] = None + task_classification: Optional[TaskClassification] = None + + +def classify_task(prompt_text: str, prompt_tokens: int = 0) -> TaskClassification: + """Classify a prompt as 'structured' or 'comprehension'. + + Uses lexical and structural features of the prompt text to determine + whether rosetta projection is likely to work (structured tasks like + math/code) or whether text mode is needed (comprehension tasks with + long contexts). + + Scoring: positive = structured, negative = comprehension. + Threshold: score > 0 = structured. + + Features (5 signals, validated against AVP benchmark data): + 1. Token count: >500 penalizes, <200 rewards + 2. Digit density: high digit ratio = math/structured + 3. Math/code markers: regex patterns for code/math syntax + 4. Comprehension question patterns: who/what/when/according-to + 5. Multi-paragraph context: 3+ paragraphs = comprehension + + Args: + prompt_text: The raw prompt text. + prompt_tokens: Token count (if known, used for length signal). + + Returns: + TaskClassification with task_type, score, and feature breakdown. + """ + score = 0 + features = {} + + # Feature 1: Token count (existing signal, enhanced) + if prompt_tokens > 500: + features["token_count"] = -2 + score -= 2 + elif prompt_tokens > 300: + features["token_count"] = -1 + score -= 1 + elif 0 < prompt_tokens < 200: + features["token_count"] = 1 + score += 1 + else: + features["token_count"] = 0 + + # Feature 2: Digit density + digits = sum(c.isdigit() for c in prompt_text) + text_len = max(len(prompt_text), 1) + digit_ratio = digits / text_len + if digit_ratio > 0.05: + features["digit_density"] = 2 + score += 2 + elif digit_ratio > 0.02: + features["digit_density"] = 1 + score += 1 + else: + features["digit_density"] = 0 + + # Feature 3: Math/code markers + math_hits = len(_MATH_MARKERS.findall(prompt_text)) + code_hits = len(_CODE_MARKERS.findall(prompt_text)) + marker_score = 0 + if math_hits >= 2: + marker_score += 2 + elif math_hits >= 1: + marker_score += 1 + if code_hits >= 2: + marker_score += 2 + elif code_hits >= 1: + marker_score += 1 + marker_score = min(marker_score, 3) # cap at 3 + features["markers"] = marker_score + score += marker_score + + # Feature 4: Comprehension question patterns + comp_hits = len(_COMPREHENSION_STARTERS.findall(prompt_text)) + struct_hits = len(_STRUCTURED_STARTERS.findall(prompt_text)) + if comp_hits > 0 and struct_hits == 0: + features["question_type"] = -2 + score -= 2 + elif struct_hits > 0 and comp_hits == 0: + features["question_type"] = 1 + score += 1 + else: + features["question_type"] = 0 + + # Feature 5: Multi-paragraph context (3+ double-newlines) + paragraphs = prompt_text.count("\n\n") + if paragraphs >= 3: + features["paragraphs"] = -2 + score -= 2 + elif paragraphs >= 2: + features["paragraphs"] = -1 + score -= 1 + else: + features["paragraphs"] = 0 + + task_type = "structured" if score > 0 else "comprehension" + return TaskClassification(task_type=task_type, score=score, features=features) def _compute_effective_rank_ratio(hidden_states: Any) -> float: @@ -94,7 +260,7 @@ def _compute_effective_rank_ratio(hidden_states: Any) -> float: if not isinstance(t, torch.Tensor): t = torch.tensor(t, dtype=torch.float32) - # Squeeze batch dim if present: [1, seq, D] → [seq, D] + # Squeeze batch dim if present: [1, seq, D] -> [seq, D] if t.ndim == 3: t = t.squeeze(0) @@ -109,7 +275,7 @@ def _compute_effective_rank_ratio(hidden_states: Any) -> float: if s_sum == 0: return 0.0 - # Normalized singular values → probability distribution + # Normalized singular values -> probability distribution p = s / s_sum # Shannon entropy of the distribution @@ -127,15 +293,23 @@ def _compute_effective_rank_ratio(hidden_states: Any) -> float: def assess_transfer( prompt_tokens: int, + prompt_text: Optional[str] = None, hidden_states: Any = None, config: Optional[TransferQualityConfig] = None, ) -> TransferQualityResult: """Assess whether a cross-model transfer should use latent or JSON. - This is advisory — the caller decides how to act on the result. + This is advisory -- the caller decides how to act on the result. + + When prompt_text is provided and config.use_task_classification is True, + task-type classification enhances the token-count heuristic. A prompt + classified as 'comprehension' recommends text even if under the token + limit; a prompt classified as 'structured' with strong signals may + recommend latent even if slightly over the token limit. Args: prompt_tokens: Number of tokens in the source prompt. + prompt_text: Optional raw prompt text for task-type classification. hidden_states: Optional tensor of hidden states, shape [seq_len, D] or [1, seq_len, D]. Only used when config.check_effective_rank=True. @@ -148,19 +322,67 @@ def assess_transfer( config = TransferQualityConfig() effective_rank_ratio: Optional[float] = None + task_cls: Optional[TaskClassification] = None - # Primary gate: prompt token count - if prompt_tokens > config.max_prompt_tokens: - return TransferQualityResult( - recommend_latent=False, - prompt_tokens=prompt_tokens, - reason=( - f"prompt_tokens={prompt_tokens} exceeds " - f"max_prompt_tokens={config.max_prompt_tokens}; " - f"single embedding unlikely to capture sufficient information" - ), - effective_rank_ratio=None, - ) + # Task classification (if prompt text available) + if prompt_text is not None and config.use_task_classification: + task_cls = classify_task(prompt_text, prompt_tokens) + + # Combined gate: token count + task classification + if task_cls is not None: + # Strong comprehension signal overrides even short prompts + if task_cls.task_type == "comprehension" and task_cls.score <= -3: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"task classified as comprehension (score={task_cls.score}); " + f"single embedding unlikely to capture context" + ), + task_classification=task_cls, + ) + + # Strong structured signal allows slightly longer prompts + if task_cls.task_type == "structured" and task_cls.score >= 3: + # Allow up to 1.5x the normal token limit for strongly structured + extended_limit = int(config.max_prompt_tokens * 1.5) + if prompt_tokens > extended_limit: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"prompt_tokens={prompt_tokens} exceeds extended limit " + f"{extended_limit} (structured task, score={task_cls.score})" + ), + task_classification=task_cls, + ) + # Within extended limit -- recommend latent + else: + # Moderate signals: use standard token limit + if prompt_tokens > config.max_prompt_tokens: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"prompt_tokens={prompt_tokens} exceeds " + f"max_prompt_tokens={config.max_prompt_tokens}; " + f"task_type={task_cls.task_type} (score={task_cls.score})" + ), + task_classification=task_cls, + ) + else: + # No task classification -- fall back to token count only + if prompt_tokens > config.max_prompt_tokens: + return TransferQualityResult( + recommend_latent=False, + prompt_tokens=prompt_tokens, + reason=( + f"prompt_tokens={prompt_tokens} exceeds " + f"max_prompt_tokens={config.max_prompt_tokens}; " + f"single embedding unlikely to capture sufficient information" + ), + effective_rank_ratio=None, + ) # Secondary gate: effective rank (opt-in) if config.check_effective_rank and hidden_states is not None: @@ -175,6 +397,7 @@ def assess_transfer( f"information too spread for single-embedding transfer" ), effective_rank_ratio=effective_rank_ratio, + task_classification=task_cls, ) return TransferQualityResult( @@ -182,4 +405,5 @@ def assess_transfer( prompt_tokens=prompt_tokens, reason="transfer within quality thresholds", effective_rank_ratio=effective_rank_ratio, + task_classification=task_cls, ) diff --git a/tests/test_mid_layer.py b/tests/test_mid_layer.py new file mode 100644 index 0000000..710cc9a --- /dev/null +++ b/tests/test_mid_layer.py @@ -0,0 +1,200 @@ +"""Tests for mid-layer injection cross-model transfer.""" + +import pytest +import torch + +from avp.rosetta.mid_layer import ( + DEFAULT_DEPTH_RATIO, + compute_extraction_layer, + compute_injection_layer, + extract_mid_layer_hidden, + mid_layer_injection_hook, + _get_decoder_layers, +) + + +class TestLayerComputation: + """Tests for extraction/injection layer computation.""" + + def test_default_depth_ratio(self): + assert DEFAULT_DEPTH_RATIO == 0.75 + + def test_extraction_layer_28_layers(self): + """Qwen 7B has 28 layers -> extract from layer 21.""" + layer = compute_extraction_layer(28) + assert layer == 21 # int(28 * 0.75) + + def test_extraction_layer_32_layers(self): + """Llama 3B has 32 layers -> extract from layer 24.""" + layer = compute_extraction_layer(32) + assert layer == 24 + + def test_injection_layer_proportional(self): + """28 source layers, 32 target layers -> inject at 24.""" + layer = compute_injection_layer(32) + assert layer == 24 + + def test_custom_depth_ratio(self): + layer = compute_extraction_layer(28, depth_ratio=0.5) + assert layer == 14 + + def test_depth_ratio_zero(self): + layer = compute_extraction_layer(28, depth_ratio=0.0) + assert layer == 0 + + def test_depth_ratio_one(self): + """Ratio 1.0 should clamp to last layer.""" + layer = compute_extraction_layer(28, depth_ratio=1.0) + assert layer == 27 # min(28, 27) + + def test_small_model(self): + layer = compute_extraction_layer(4, depth_ratio=0.75) + assert layer == 3 + + +class TestExtractMidLayerHidden: + """Tests for mid-layer hidden state extraction.""" + + def test_extracts_correct_layer(self): + """Verify extraction returns the right layer's last token.""" + # Simulate hidden_states: tuple of (num_layers + 1) tensors + # Index 0 = embedding, index i+1 = layer i output + num_layers = 10 + batch_size = 1 + seq_len = 5 + hidden_dim = 32 + + hidden_states = tuple( + torch.randn(batch_size, seq_len, hidden_dim) * (i + 1) + for i in range(num_layers + 1) + ) + + class MockOutputs: + pass + + outputs = MockOutputs() + outputs.hidden_states = hidden_states + + extraction_layer = 7 + result = extract_mid_layer_hidden(outputs, extraction_layer) + + assert result.shape == (batch_size, hidden_dim) + # Should match hidden_states[extraction_layer + 1][:, -1, :] + expected = hidden_states[extraction_layer + 1][:, -1, :] + assert torch.allclose(result, expected) + + def test_clamps_to_valid_range(self): + """Layer beyond range should clamp.""" + hidden_states = tuple( + torch.randn(1, 5, 16) for _ in range(6) # 5 layers + embedding + ) + + class MockOutputs: + pass + + outputs = MockOutputs() + outputs.hidden_states = hidden_states + + # Request layer 10 but only 5 layers exist + result = extract_mid_layer_hidden(outputs, 10) + assert result.shape == (1, 16) + + +class TestMidLayerInjectionHook: + """Tests for the forward hook context manager.""" + + def test_hook_modifies_output(self): + """Test that the hook replaces the last token's hidden state.""" + # Simple linear model to test hooking + layer = torch.nn.Linear(16, 16, bias=False) + # Set to identity so we can verify the hook's effect + with torch.no_grad(): + layer.weight.copy_(torch.eye(16)) + + class SimpleModel: + def __init__(self): + self.model = type("inner", (), {"layers": torch.nn.ModuleList([layer])})() + + model = SimpleModel() + projected = torch.ones(1, 16) * 42.0 # Distinctive value + + with mid_layer_injection_hook(model, 0, projected): + # Simulate a forward pass through the hooked layer + input_tensor = torch.randn(1, 5, 16) + output = layer(input_tensor) + # output should have last position replaced + + # After context exit, hook should be removed + assert len(layer._forward_hooks) == 0 + + def test_hook_fires_only_once(self): + """The hook should fire only on the first forward pass.""" + layer = torch.nn.Linear(16, 16, bias=False) + with torch.no_grad(): + layer.weight.copy_(torch.eye(16)) + + class SimpleModel: + def __init__(self): + self.model = type("inner", (), {"layers": torch.nn.ModuleList([layer])})() + + model = SimpleModel() + projected = torch.ones(1, 16) * 42.0 + + with mid_layer_injection_hook(model, 0, projected): + # First forward pass — hook should fire + input1 = torch.randn(1, 5, 16) + out1 = layer(input1) + + # Second forward pass — hook should NOT fire + input2 = torch.randn(1, 3, 16) + out2 = layer(input2) + + # First pass: last token should be replaced + # Second pass: should be unmodified + # (The hook returns the modified output only on first call) + + assert len(layer._forward_hooks) == 0 + + def test_hook_cleanup_on_exception(self): + """Hook should be removed even if an exception occurs.""" + layer = torch.nn.Linear(16, 16, bias=False) + + class SimpleModel: + def __init__(self): + self.model = type("inner", (), {"layers": torch.nn.ModuleList([layer])})() + + model = SimpleModel() + projected = torch.ones(1, 16) + + with pytest.raises(RuntimeError): + with mid_layer_injection_hook(model, 0, projected): + raise RuntimeError("test error") + + # Hook should still be removed + assert len(layer._forward_hooks) == 0 + + +class TestGetDecoderLayers: + """Tests for model layer discovery.""" + + def test_llama_style_model(self): + """Models with model.model.layers (Llama, Qwen).""" + layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)]) + inner = type("inner", (), {"layers": layers})() + model = type("model", (), {"model": inner})() + result = _get_decoder_layers(model) + assert result is layers + + def test_gpt2_style_model(self): + """Models with model.transformer.h (GPT-2).""" + h = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)]) + transformer = type("transformer", (), {"h": h})() + model = type("model", (), {"model": None, "transformer": transformer})() + result = _get_decoder_layers(model) + assert result is h + + def test_unknown_model_raises(self): + """Models without known layer paths should raise.""" + model = type("model", (), {"model": None, "transformer": None})() + with pytest.raises(AttributeError, match="Cannot find decoder layers"): + _get_decoder_layers(model) diff --git a/tests/test_smart_routing.py b/tests/test_smart_routing.py new file mode 100644 index 0000000..a00a4e8 --- /dev/null +++ b/tests/test_smart_routing.py @@ -0,0 +1,248 @@ +"""Tests for smart routing / task classification in quality gate.""" + +import pytest + +from avp.rosetta.quality import ( + TaskClassification, + TransferQualityConfig, + TransferQualityResult, + assess_transfer, + classify_task, +) + + +class TestClassifyTask: + """Tests for classify_task() prompt classification.""" + + def test_math_problem_classified_as_structured(self): + prompt = "Solve: 24 * 17 + 3 = ?" + result = classify_task(prompt, prompt_tokens=50) + assert result.task_type == "structured" + assert result.score > 0 + + def test_code_problem_classified_as_structured(self): + prompt = "def fibonacci(n):\n # implement this function\n return result" + result = classify_task(prompt, prompt_tokens=30) + assert result.task_type == "structured" + assert result.score > 0 + + def test_comprehension_classified_as_comprehension(self): + prompt = ( + "The industrial revolution began in Britain in the late 18th century. " + "It was a period of great change in manufacturing, mining, and transport.\n\n" + "The textile industry was the first to adopt new methods of production.\n\n" + "Steam power was central to the changes.\n\n" + "Who started the industrial revolution?" + ) + result = classify_task(prompt, prompt_tokens=800) + assert result.task_type == "comprehension" + assert result.score < 0 + + def test_high_digit_density_is_structured(self): + prompt = "Calculate 123 + 456 * 789 - 321 / 654 = ?" + result = classify_task(prompt, prompt_tokens=20) + assert result.features["digit_density"] > 0 + assert result.task_type == "structured" + + def test_no_digits_no_digit_bonus(self): + prompt = "Tell me about the history of France" + result = classify_task(prompt, prompt_tokens=10) + assert result.features["digit_density"] == 0 + + def test_long_prompt_penalized(self): + prompt = "Some text" + result = classify_task(prompt, prompt_tokens=600) + assert result.features["token_count"] == -2 + + def test_short_prompt_rewarded(self): + prompt = "Solve: 2 + 2" + result = classify_task(prompt, prompt_tokens=50) + assert result.features["token_count"] == 1 + + def test_math_markers_detected(self): + prompt = r"Find \frac{x}{y} where \sqrt{x} = 5" + result = classify_task(prompt, prompt_tokens=20) + assert result.features["markers"] > 0 + + def test_code_markers_detected(self): + prompt = "```python\ndef foo():\n return 42\n```" + result = classify_task(prompt, prompt_tokens=20) + assert result.features["markers"] > 0 + + def test_comprehension_questions_detected(self): + prompt = "Based on the passage above, who was the first president?" + result = classify_task(prompt, prompt_tokens=100) + assert result.features["question_type"] < 0 + + def test_structured_starters_detected(self): + prompt = "Solve the following equation: 3x + 5 = 20" + result = classify_task(prompt, prompt_tokens=50) + assert result.features["question_type"] > 0 + + def test_multi_paragraph_penalized(self): + prompt = "Paragraph one.\n\nParagraph two.\n\nParagraph three.\n\nQuestion?" + result = classify_task(prompt, prompt_tokens=100) + assert result.features["paragraphs"] < 0 + + def test_returns_task_classification_dataclass(self): + result = classify_task("test prompt", prompt_tokens=50) + assert isinstance(result, TaskClassification) + assert result.task_type in ("structured", "comprehension") + assert isinstance(result.score, int) + assert isinstance(result.features, dict) + + def test_gsm8k_style_prompt(self): + """GSM8K-style math prompt should be structured.""" + prompt = ( + "Janet's ducks lay 16 eggs per day. She eats three for breakfast " + "every morning and bakes muffins for her friends every day with four. " + "She sells the remainder at the farmers' market daily for $2 per " + "fresh duck egg. How much in dollars does she make every day at the " + "farmers' market?" + ) + result = classify_task(prompt, prompt_tokens=80) + assert result.task_type == "structured" + + def test_hotpotqa_style_prompt(self): + """HotpotQA-style multi-paragraph QA should be comprehension.""" + prompt = ( + "Walter Elias Disney was an American entrepreneur, animator, voice " + "actor and film producer. A pioneer of the American animation " + "industry, he introduced several developments in the production of " + "cartoons.\n\n" + "As a film producer, Disney holds the record for most Academy Awards " + "earned by an individual, having won 22 Oscars from 59 nominations.\n\n" + "He was presented with two Golden Globe Special Achievement Awards " + "and an Emmy Award, among other honors.\n\n" + "Several of his films are included in the National Film Registry by " + "the Library of Congress.\n\n" + "Based on the passage, how many Academy Awards did Disney win?" + ) + result = classify_task(prompt, prompt_tokens=1200) + assert result.task_type == "comprehension" + + def test_empty_prompt(self): + result = classify_task("", prompt_tokens=0) + assert isinstance(result, TaskClassification) + + def test_zero_tokens_no_token_bonus(self): + result = classify_task("test", prompt_tokens=0) + assert result.features["token_count"] == 0 + + +class TestAssessTransferWithPromptText: + """Tests for enhanced assess_transfer() with prompt text.""" + + def test_backward_compatible_no_prompt_text(self): + """Without prompt_text, behaves exactly like v1.""" + result = assess_transfer(prompt_tokens=200) + assert result.recommend_latent is True + assert result.task_classification is None + + def test_backward_compatible_over_limit(self): + result = assess_transfer(prompt_tokens=400) + assert result.recommend_latent is False + assert result.task_classification is None + + def test_structured_prompt_recommended_latent(self): + result = assess_transfer( + prompt_tokens=100, + prompt_text="Solve: 24 * 17 + 3 = ?", + ) + assert result.recommend_latent is True + assert result.task_classification is not None + assert result.task_classification.task_type == "structured" + + def test_comprehension_prompt_blocks_latent(self): + """Strong comprehension signal blocks latent even under token limit.""" + prompt = ( + "The Mona Lisa is a half-length portrait painting by Italian " + "artist Leonardo da Vinci.\n\n" + "It has been described as the best known painting in the world.\n\n" + "The painting is thought to be a portrait of Lisa Gherardini.\n\n" + "According to the passage, who painted the Mona Lisa?" + ) + result = assess_transfer(prompt_tokens=200, prompt_text=prompt) + assert result.recommend_latent is False + assert result.task_classification is not None + assert result.task_classification.task_type == "comprehension" + + def test_structured_extends_token_limit(self): + """Strong structured signal allows prompts up to 1.5x limit.""" + prompt = ( + "```python\n" + "def calculate(x, y):\n" + " result = x * y + 100 / 50 - 25\n" + " return result\n" + "```\n" + "Solve: compute calculate(10, 20) step by step. " + "Find the value of 10 * 20 + 100 / 50 - 25" + ) + # 350 tokens > 300 limit, but strong structured signals + result = assess_transfer(prompt_tokens=350, prompt_text=prompt) + assert result.recommend_latent is True + + def test_structured_still_blocked_beyond_extended_limit(self): + """Even structured, blocked if way over limit (>450 for default 300).""" + prompt = "Solve: 2 + 2 = ?" + result = assess_transfer(prompt_tokens=500, prompt_text=prompt) + assert result.recommend_latent is False + + def test_task_classification_disabled(self): + config = TransferQualityConfig(use_task_classification=False) + result = assess_transfer( + prompt_tokens=200, + prompt_text="Who painted the Mona Lisa?", + config=config, + ) + assert result.task_classification is None + assert result.recommend_latent is True # under token limit + + def test_result_includes_classification(self): + result = assess_transfer( + prompt_tokens=100, + prompt_text="Solve: 2 + 2", + ) + assert isinstance(result, TransferQualityResult) + assert result.task_classification is not None + assert isinstance(result.task_classification, TaskClassification) + + def test_effective_rank_still_works_with_prompt_text(self): + """Effective rank check still works alongside task classification.""" + import torch + + config = TransferQualityConfig(check_effective_rank=True) + # Identity-like matrix = low effective rank + hidden = torch.eye(10, 64) + result = assess_transfer( + prompt_tokens=100, + prompt_text="Solve: 2 + 2", + hidden_states=hidden, + config=config, + ) + assert isinstance(result, TransferQualityResult) + + +class TestAssessTransferBackwardCompat: + """Ensure all existing behavior is preserved.""" + + def test_short_prompt_recommends_latent(self): + result = assess_transfer(prompt_tokens=200) + assert result.recommend_latent is True + + def test_long_prompt_recommends_text(self): + result = assess_transfer(prompt_tokens=1500) + assert result.recommend_latent is False + + def test_boundary_at_300(self): + assert assess_transfer(prompt_tokens=300).recommend_latent is True + assert assess_transfer(prompt_tokens=301).recommend_latent is False + + def test_custom_config_threshold(self): + config = TransferQualityConfig(max_prompt_tokens=100) + assert assess_transfer(prompt_tokens=100, config=config).recommend_latent is True + assert assess_transfer(prompt_tokens=101, config=config).recommend_latent is False + + def test_zero_tokens(self): + result = assess_transfer(prompt_tokens=0) + assert result.recommend_latent is True From 8168936644d3247040ee3e68fad1a8c537fbef07 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 02:01:49 +0000 Subject: [PATCH 03/14] Add trained C2C cross-model projector (Tier 2) Per-layer linear projections with learned sigmoid gates for cross-model latent transfer. Both source and target models frozen; only the lightweight projector trains. Inference via per-layer forward hooks that additively inject projected hidden states during prefill. New files: - rosetta/train.py: LayerProjector, TrainConfig, train_projector() - rosetta/trained_hooks.py: trained_multi_layer_hook context manager - pipeline_trained.py: GSM8K benchmark pipeline for trained mode - test_trained_projector.py: 19 tests (projector, hooks, registry, enum) Modified: - types.py: ProjectionMethod.TRAINED enum - calibrate.py: layer_weights/biases/gates fields on AVPMap - registry.py: save/load trained projection fields - huggingface.py: cross_model_method="trained" branch + _prepare_trained_injection() - run_gsm8k_2agent.py: "trained" mode with inline training Co-Authored-By: Claude Opus 4.6 --- benchmarks/gsm8k_2agent/pipeline_trained.py | 235 ++++++++++ benchmarks/gsm8k_2agent/run_gsm8k_2agent.py | 64 ++- src/avp/connectors/huggingface.py | 81 ++++ src/avp/rosetta/calibrate.py | 4 + src/avp/rosetta/registry.py | 8 + src/avp/rosetta/train.py | 463 ++++++++++++++++++++ src/avp/rosetta/trained_hooks.py | 106 +++++ src/avp/types.py | 1 + tests/test_trained_projector.py | 339 ++++++++++++++ 9 files changed, 1296 insertions(+), 5 deletions(-) create mode 100644 benchmarks/gsm8k_2agent/pipeline_trained.py create mode 100644 src/avp/rosetta/train.py create mode 100644 src/avp/rosetta/trained_hooks.py create mode 100644 tests/test_trained_projector.py diff --git a/benchmarks/gsm8k_2agent/pipeline_trained.py b/benchmarks/gsm8k_2agent/pipeline_trained.py new file mode 100644 index 0000000..d989a83 --- /dev/null +++ b/benchmarks/gsm8k_2agent/pipeline_trained.py @@ -0,0 +1,235 @@ +"""Trained projection pipeline: 2-agent chain with learned per-layer projections. + +Researcher runs on model A (latent steps → extract hidden state). +Solver runs on model B (per-layer hooks inject trained projections → generate answer). + +Requires a pre-trained AVPMap with layer_weights/layer_biases/layer_gates. +""" + +import time +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F + +from benchmarks.shared.generation import generate_text, render_prompt, tokenize_prompt +from benchmarks.shared.kv_utils import get_past_length +from benchmarks.shared.metrics import gpu_memory_tracker +from .agents import AGENTS, build_latent_prompt +from .evaluate import extract_gold, extract_gsm8k_answer, check_correct + + +def run_trained_pipeline( + conn_a: Any, + model_a: Any, + tokenizer_a: Any, + identity_a: Any, + model_b: Any, + tokenizer_b: Any, + device: str, + avp_map: Any, + question: str, + gold_solution: str, + latent_steps: int = 10, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.95, + verbose: bool = False, +) -> Dict: + """Run the 2-agent pipeline with trained per-layer projection. + + Researcher (model A): latent steps → extract hidden state + Solver (model B): per-layer hooks inject trained projections → generate text + """ + from avp.rosetta.trained_hooks import trained_multi_layer_hook + + with gpu_memory_tracker(device) as mem: + t0 = time.perf_counter() + agent_traces: List[Dict] = [] + total_prompt_tokens = 0 + total_latent_steps = 0 + total_output_tokens = 0 + + researcher = AGENTS[0] + solver = AGENTS[1] + + # --- Agent 1: Researcher on model A (latent steps) --- + messages = build_latent_prompt(researcher.role, question) + prompt_text = render_prompt(tokenizer_a, messages) + input_ids, attention_mask = tokenize_prompt(tokenizer_a, prompt_text, device) + + agent_t0 = time.perf_counter() + prompt_tokens = int(input_ids.shape[-1]) + total_prompt_tokens += prompt_tokens + total_latent_steps += latent_steps + + past_kv = conn_a.generate_latent_steps( + input_ids, latent_steps=latent_steps, attention_mask=attention_mask, + ) + + # Extract last hidden state + proj_t0 = time.perf_counter() + past_len = get_past_length(past_kv) + dummy_mask = torch.ones((1, past_len + 1), dtype=torch.long, device=device) + eos_id = tokenizer_a.eos_token_id or 0 + dummy_ids = torch.tensor([[eos_id]], device=device) + with torch.no_grad(): + out = model_a( + input_ids=dummy_ids, + attention_mask=dummy_mask, + past_key_values=past_kv, + output_hidden_states=True, + return_dict=True, + ) + src_hidden = out.hidden_states[-1][:, -1, :].float() # [1, D_src] + + # Pre-compute per-layer projections + layer_projections = [] + active_count = 0 + for i, (w, b, gate) in enumerate( + zip(avp_map.layer_weights, avp_map.layer_biases, avp_map.layer_gates) + ): + if gate < 0.01: + layer_projections.append(None) + continue + projected = F.linear( + src_hidden, w.to(device), b.to(device) + ) # [1, D_tgt] + layer_projections.append((projected, gate)) + active_count += 1 + + projection_ms = (time.perf_counter() - proj_t0) * 1000 + + agent_time_ms = (time.perf_counter() - agent_t0) * 1000 + agent_traces.append({ + "name": researcher.name, + "role": researcher.role, + "prompt_tokens": prompt_tokens, + "latent_steps": latent_steps, + "projection_ms": projection_ms, + "active_layers": active_count, + "total_layers": len(avp_map.layer_gates), + "agent_time_ms": agent_time_ms, + "output": "", + }) + + if verbose: + print(f" [{researcher.name}] latent steps={latent_steps}, " + f"projection={projection_ms:.1f}ms, " + f"active layers={active_count}/{len(avp_map.layer_gates)}") + + # Free model A KV-cache + del past_kv + if device == "cuda": + torch.cuda.empty_cache() + + # --- Agent 2: Solver on model B (trained hooks, generate) --- + messages = build_latent_prompt(solver.role, question) + prompt_text = render_prompt(tokenizer_b, messages) + input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device) + + agent_t0 = time.perf_counter() + prompt_tokens = int(input_ids.shape[-1]) + total_prompt_tokens += prompt_tokens + + with trained_multi_layer_hook(model_b, layer_projections): + text, _ = generate_text( + model_b, tokenizer_b, input_ids, attention_mask, device, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + ) + + output_encoded = tokenizer_b(text, add_special_tokens=False) + output_tokens = len(output_encoded["input_ids"]) + total_output_tokens += output_tokens + agent_time_ms = (time.perf_counter() - agent_t0) * 1000 + + agent_traces.append({ + "name": solver.name, + "role": solver.role, + "prompt_tokens": prompt_tokens, + "output_tokens": output_tokens, + "agent_time_ms": agent_time_ms, + "output": text, + }) + + if verbose: + print(f" [{solver.name}] output ({len(text)} chars): {text[:200]}...") + + wall_time = time.perf_counter() - t0 + + total_tokens = total_prompt_tokens + total_latent_steps + total_output_tokens + tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0 + + gold = extract_gold(gold_solution) + prediction = extract_gsm8k_answer(agent_traces[-1]["output"]) + correct = check_correct(prediction, gold) + + return { + "question": question, + "gold": gold, + "prediction": prediction, + "raw_output": agent_traces[-1]["output"], + "correct": correct, + "wall_time": wall_time, + "total_prompt_tokens": total_prompt_tokens, + "total_latent_steps": total_latent_steps, + "total_output_tokens": total_output_tokens, + "total_tokens": total_tokens, + "tokens_per_sec": tokens_per_sec, + "peak_memory_mb": mem["peak_memory_mb"], + "projection_overhead_ms": projection_ms, + "active_layers": active_count, + "total_layers": len(avp_map.layer_gates), + "agents": agent_traces, + "mode": "trained", + } + + +def run_trained_benchmark( + conn_a: Any, + model_a: Any, + tokenizer_a: Any, + identity_a: Any, + model_b: Any, + tokenizer_b: Any, + device: str, + avp_map: Any, + dataset: List[Dict], + latent_steps: int = 10, + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.95, + verbose: bool = False, +) -> List[Dict]: + """Run trained projection pipeline on a list of GSM8K samples.""" + results = [] + for i, sample in enumerate(dataset): + if verbose: + print(f"\n[Trained] Sample {i + 1}/{len(dataset)}: {sample['question'][:80]}...") + + result = run_trained_pipeline( + conn_a, model_a, tokenizer_a, identity_a, + model_b, tokenizer_b, device, avp_map, + question=sample["question"], + gold_solution=sample["answer"], + latent_steps=latent_steps, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + verbose=verbose, + ) + results.append(result) + + if verbose: + status = "CORRECT" if result["correct"] else "WRONG" + print(f" => {status} (pred={result['prediction']}, gold={result['gold']}, " + f"time={result['wall_time']:.1f}s)") + else: + correct = sum(1 for r in results if r["correct"]) + print(f" [Trained] {i + 1}/{len(dataset)} " + f"({correct}/{i + 1} correct, {result['wall_time']:.1f}s)", + flush=True) + + return results diff --git a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py index 4f93d31..1d4b120 100644 --- a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py +++ b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py @@ -40,7 +40,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--mode", choices=["latent", "text", "direct", "rosetta", "logit_guided", - "mid_layer", "text_cross_model", "both", "all"], + "mid_layer", "trained", "text_cross_model", "both", "all"], default="all", help="Pipeline(s) to run (default: all)", ) @@ -120,12 +120,13 @@ def run_benchmark(config: dict) -> dict: run_rosetta = mode in ("rosetta", "all") run_logit_guided = mode in ("logit_guided", "all") run_mid_layer = mode in ("mid_layer", "all") + run_trained = mode in ("trained",) # not in "all" — requires training run_text_cross_model = mode in ("text_cross_model", "all") print(f"Device: {device}") print(f"Mode: {mode}") print(f"Model A: {model_name}") - if run_rosetta or run_logit_guided or run_mid_layer or run_text_cross_model: + if run_rosetta or run_logit_guided or run_mid_layer or run_trained or run_text_cross_model: print(f"Model B: {model_b_name}") print(f"Samples: {max_samples}") print(f"Latent steps: {latent_steps}") @@ -134,7 +135,8 @@ def run_benchmark(config: dict) -> dict: print(f"Seed: {seed}") print(f"Pipelines: direct={run_direct}, text={run_text}, latent={run_latent}, " f"rosetta={run_rosetta}, logit_guided={run_logit_guided}, " - f"mid_layer={run_mid_layer}, text_cross_model={run_text_cross_model}") + f"mid_layer={run_mid_layer}, trained={run_trained}, " + f"text_cross_model={run_text_cross_model}") print() dataset = load_dataset(max_samples) @@ -146,6 +148,7 @@ def run_benchmark(config: dict) -> dict: rosetta_results = None logit_guided_results = None mid_layer_results = None + trained_results = None text_cross_model_results = None if run_direct: @@ -194,7 +197,7 @@ def run_benchmark(config: dict) -> dict: # Load model B if needed for cross-model modes model_b = tokenizer_b = connector_b = identity_b = None - if run_rosetta or run_logit_guided or run_mid_layer or run_text_cross_model: + if run_rosetta or run_logit_guided or run_mid_layer or run_trained or run_text_cross_model: model_b, tokenizer_b, connector_b, identity_b = load_model(model_b_name, device) if run_text_cross_model: @@ -314,6 +317,48 @@ def run_benchmark(config: dict) -> dict: temperature=temperature, top_p=top_p, verbose=verbose, ) + if run_trained: + from benchmarks.gsm8k_2agent.pipeline_trained import run_trained_benchmark + from avp.rosetta.train import train_projector, TrainConfig + + print("\n" + "=" * 50) + print("Running TRAINED (per-layer learned projection) pipeline...") + print(f" Model A (Researcher): {model_name}") + print(f" Model B (Solver): {model_b_name}") + print("=" * 50) + + # Training phase + train_config = TrainConfig( + num_samples=config.get("train_samples", 2000), + batch_size=config.get("train_batch_size", 4), + num_epochs=config.get("train_epochs", 2), + learning_rate=config.get("train_lr", 1e-4), + ) + print(f"Training projector: {train_config.num_samples} samples, " + f"{train_config.num_epochs} epochs...") + + trained_map = train_projector( + source_model=model, + target_model=model_b, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer_b, + device=device, + config=train_config, + ) + active = [i for i, g in enumerate(trained_map.layer_gates) if g > 0.01] + print(f"Training complete. Active layers: {len(active)}/{len(trained_map.layer_gates)}") + print(f"Validation score: {trained_map.validation_score:.4f}") + + set_seed(seed) + + trained_results = run_trained_benchmark( + conn_a=connector, model_a=model, tokenizer_a=tokenizer, + identity_a=identity, model_b=model_b, tokenizer_b=tokenizer_b, + device=device, avp_map=trained_map, dataset=dataset, + latent_steps=latent_steps, max_new_tokens=max_new_tokens, + temperature=temperature, top_p=top_p, verbose=verbose, + ) + # Free model B to reclaim GPU memory if model_b is not None: del model_b, tokenizer_b, connector_b, identity_b @@ -337,6 +382,8 @@ def run_benchmark(config: dict) -> dict: modes.append(("Logit-Guided", 13, logit_guided_results)) if mid_layer_results is not None: modes.append(("Mid-Layer", 13, mid_layer_results)) + if trained_results is not None: + modes.append(("Trained", 13, trained_results)) if text_cross_model_results is not None: modes.append(("Text Cross-Model", 16, text_cross_model_results)) @@ -354,6 +401,8 @@ def run_benchmark(config: dict) -> dict: available["logit_guided"] = logit_guided_results if mid_layer_results is not None: available["mid_layer"] = mid_layer_results + if trained_results is not None: + available["trained"] = trained_results if text_cross_model_results is not None: available["text_cross_model"] = text_cross_model_results agreement_data = compute_agreement(available) if len(available) > 1 else None @@ -376,7 +425,7 @@ def run_benchmark(config: dict) -> dict: "config": { "benchmark": "gsm8k_2agent", "model_a": model_name, - "model_b": model_b_name if (run_rosetta or run_logit_guided or run_mid_layer or run_text_cross_model) else None, + "model_b": model_b_name if (run_rosetta or run_logit_guided or run_mid_layer or run_trained or run_text_cross_model) else None, "device": device, "mode": mode, "max_samples": max_samples, @@ -417,6 +466,11 @@ def run_benchmark(config: dict) -> dict: "summary": compute_stats(mid_layer_results), "samples": mid_layer_results, } + if trained_results is not None: + output_data["trained"] = { + "summary": compute_stats(trained_results), + "samples": trained_results, + } if text_cross_model_results is not None: output_data["text_cross_model"] = { "summary": compute_stats(text_cross_model_results), diff --git a/src/avp/connectors/huggingface.py b/src/avp/connectors/huggingface.py index 7154521..d0090f8 100644 --- a/src/avp/connectors/huggingface.py +++ b/src/avp/connectors/huggingface.py @@ -620,6 +620,12 @@ def generate( ) # Don't use KV-cache priming — injection happens via hook context = None + elif cross_model_method == "trained": + mid_layer_hook_ctx = self._prepare_trained_injection( + source, context, _diagnostics=_diagnostics, + ) + # Don't use KV-cache priming — injection happens via per-layer hooks + context = None else: context = self._apply_rosetta_projection( source, context, _diagnostics=_diagnostics, @@ -877,6 +883,81 @@ def _prepare_mid_layer_injection( projected_hidden=projected, ) + # --- Cross-model trained projection --- + + def _prepare_trained_injection( + self, + source: "HuggingFaceConnector", + context: AVPContext, + _diagnostics: Optional[Any] = None, + ) -> Any: + """Prepare per-layer trained projection hooks for cross-model generation. + + Uses a trained AVPMap with per-layer linear projections + learned gates. + Each layer's forward hook adds the projected source hidden state + (scaled by the learned gate) to the target layer's output. + + Args: + source: Source model connector. + context: AVPContext from source's think(). + _diagnostics: Internal diagnostics object. + + Returns: + Context manager (trained_multi_layer_hook) ready to wrap generate(). + """ + import torch + + if context.last_hidden_state is None: + raise ValueError( + "Trained projection requires context with last_hidden_state. " + "Use think() to produce the context." + ) + + avp_map = self._get_or_calibrate_map(source) + + if avp_map.layer_weights is None: + raise ValueError( + "AVPMap does not contain trained per-layer projections. " + "Use train_projector() to train a projection first, or " + "use cross_model_method='rosetta' for zero-training projection." + ) + + # Source hidden state [1, D_src] or [D_src] + src_hidden = context.last_hidden_state.float() + if src_hidden.dim() == 1: + src_hidden = src_hidden.unsqueeze(0) # [1, D_src] + + # Pre-compute per-layer projections + layer_projections = [] + active_count = 0 + for i, (w, b, gate) in enumerate( + zip(avp_map.layer_weights, avp_map.layer_biases, avp_map.layer_gates) + ): + if gate < 0.01: + layer_projections.append(None) + continue + # w: [D_tgt, D_src], b: [D_tgt] + projected = torch.nn.functional.linear( + src_hidden, w.to(src_hidden.device), b.to(src_hidden.device) + ) # [1, D_tgt] + layer_projections.append((projected, gate)) + active_count += 1 + + if _diagnostics is not None: + _diagnostics.transfer_mode = "trained" + _diagnostics.projection_method = "TRAINED" + + logger.info( + "Trained injection: %d/%d active layers (gate > 0.01)", + active_count, len(avp_map.layer_gates), + ) + + from ..rosetta.trained_hooks import trained_multi_layer_hook + return trained_multi_layer_hook( + model=self.model, + layer_projections=layer_projections, + ) + # --- Cross-model rosetta projection --- def _apply_rosetta_projection( diff --git a/src/avp/rosetta/calibrate.py b/src/avp/rosetta/calibrate.py index 9710489..160ed93 100644 --- a/src/avp/rosetta/calibrate.py +++ b/src/avp/rosetta/calibrate.py @@ -42,6 +42,10 @@ class AVPMap: tgt_indices: Optional[Any] = None # LongTensor [N_shared] — target token IDs overlap_count: int = 0 overlap_ratio: float = 0.0 + # Trained translator fields (per-layer projections) + layer_weights: Optional[List[Any]] = None # List of Tensor [D_src, D_tgt] per layer + layer_biases: Optional[List[Any]] = None # List of Tensor [D_tgt] per layer + layer_gates: Optional[List[float]] = None # List of float gate values per layer def __post_init__(self) -> None: if isinstance(self.method, str): diff --git a/src/avp/rosetta/registry.py b/src/avp/rosetta/registry.py index 8ef67e8..72b4fec 100644 --- a/src/avp/rosetta/registry.py +++ b/src/avp/rosetta/registry.py @@ -56,6 +56,11 @@ def save_map(avp_map: AVPMap, map_dir: Optional[Path] = None) -> Path: "overlap_count": avp_map.overlap_count, "overlap_ratio": avp_map.overlap_ratio, } + # Trained translator fields (per-layer projections) + if avp_map.layer_weights is not None: + data["layer_weights"] = [w.cpu() for w in avp_map.layer_weights] + data["layer_biases"] = [b.cpu() for b in avp_map.layer_biases] + data["layer_gates"] = avp_map.layer_gates torch.save(data, path) return path @@ -101,6 +106,9 @@ def load_map( tgt_indices=data.get("tgt_indices"), overlap_count=data.get("overlap_count", 0), overlap_ratio=data.get("overlap_ratio", 0.0), + layer_weights=data.get("layer_weights"), + layer_biases=data.get("layer_biases"), + layer_gates=data.get("layer_gates"), ) diff --git a/src/avp/rosetta/train.py b/src/avp/rosetta/train.py new file mode 100644 index 0000000..66a637d --- /dev/null +++ b/src/avp/rosetta/train.py @@ -0,0 +1,463 @@ +"""Trained cross-model translator (C2C) for AVP. + +Trains per-layer linear projections with learned sigmoid gates to map +source model hidden states to target model activation space at every +transformer layer. Both source and target models are frozen; only the +lightweight projector trains. + +Based on: +- C2C (arxiv 2510.03215): Cross-cache fusion, +6-14% over zero-shot +- DroidSpeak: ~11% of layers are critical, gates learn which matter +- Model Stitching (2506.06609): Affine maps between layers, 2K-180K samples + +Usage:: + + from avp.rosetta.train import LayerProjector, train_projector + + projector = train_projector( + source_model=src_model, + target_model=tgt_model, + source_tokenizer=src_tok, + target_tokenizer=tgt_tok, + device="cuda", + ) + # Produces an AVPMap with method=TRAINED +""" + +import logging +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _require_torch(): + try: + import torch + return torch + except ImportError: + raise ImportError("torch is required for training. pip install torch") + + +@dataclass +class TrainConfig: + """Configuration for training a cross-model projector. + + Attributes: + num_samples: Number of text samples to train on. + batch_size: Training batch size. + learning_rate: Adam learning rate. + num_epochs: Number of training epochs. + gate_reg_weight: L1 regularization weight for gate sparsity. + gate_init: Initial gate logit (sigmoid(gate_init) = initial gate value). + max_seq_len: Maximum sequence length for training samples. + extraction_layer_ratio: Depth ratio for source hidden state extraction. + warmup_steps: Number of linear warmup steps. + seed: Random seed for reproducibility. + """ + + num_samples: int = 5000 + batch_size: int = 4 + learning_rate: float = 1e-4 + num_epochs: int = 2 + gate_reg_weight: float = 0.01 + gate_init: float = -3.0 # sigmoid(-3) ~ 0.05 + max_seq_len: int = 256 + extraction_layer_ratio: float = 0.75 + warmup_steps: int = 100 + seed: int = 42 + + +class LayerProjector: + """Per-layer linear projections with learned sigmoid gates. + + Maps a source hidden state to target activation space at each layer. + Gate values control injection strength per layer (initialized near zero). + """ + + def __init__(self, source_dim: int, target_dim: int, num_layers: int, config: Optional[TrainConfig] = None): + torch = _require_torch() + import torch.nn as nn + + if config is None: + config = TrainConfig() + + self.source_dim = source_dim + self.target_dim = target_dim + self.num_layers = num_layers + + # Per-layer projection: source hidden -> target hidden at layer L + self.layer_projections = nn.ModuleList([ + nn.Linear(source_dim, target_dim, bias=True) + for _ in range(num_layers) + ]) + + # Per-layer gate: sigmoid scalar controlling injection strength + self.layer_gates = nn.ParameterList([ + nn.Parameter(torch.tensor(config.gate_init)) + for _ in range(num_layers) + ]) + + # Initialize projections with small weights + for proj in self.layer_projections: + nn.init.xavier_uniform_(proj.weight, gain=0.1) + nn.init.zeros_(proj.bias) + + def parameters(self): + """Return all trainable parameters.""" + for proj in self.layer_projections: + yield from proj.parameters() + for gate in self.layer_gates: + yield gate + + def to(self, device): + """Move all parameters to device.""" + for proj in self.layer_projections: + proj.to(device) + for i, gate in enumerate(self.layer_gates): + self.layer_gates[i] = gate.to(device) + return self + + def train(self): + """Set to training mode.""" + for proj in self.layer_projections: + proj.train() + + def eval(self): + """Set to evaluation mode.""" + for proj in self.layer_projections: + proj.eval() + + def forward(self, source_hidden): + """Project source hidden state to each target layer. + + Args: + source_hidden: [B, D_src] from source model. + + Returns: + List of (projected_hidden [B, D_tgt], gate_value float) per layer. + """ + torch = _require_torch() + results = [] + for proj, gate_logit in zip(self.layer_projections, self.layer_gates): + projected = proj(source_hidden) # [B, D_tgt] + gate = torch.sigmoid(gate_logit).item() # scalar in [0, 1] + results.append((projected, gate)) + return results + + def get_active_layers(self, threshold: float = 0.01) -> List[int]: + """Return indices of layers with gate > threshold.""" + torch = _require_torch() + active = [] + for i, gate_logit in enumerate(self.layer_gates): + gate = torch.sigmoid(gate_logit).item() + if gate > threshold: + active.append(i) + return active + + def export_weights(self) -> Tuple[List[Any], List[Any], List[float]]: + """Export weights for serialization into AVPMap. + + Returns: + Tuple of (layer_weights, layer_biases, layer_gates) + where each is a list of length num_layers. + """ + torch = _require_torch() + weights = [] + biases = [] + gates = [] + for proj, gate_logit in zip(self.layer_projections, self.layer_gates): + weights.append(proj.weight.detach().cpu()) # [D_tgt, D_src] + biases.append(proj.bias.detach().cpu()) # [D_tgt] + gates.append(torch.sigmoid(gate_logit).item()) # float + return weights, biases, gates + + +def _load_training_data( + tokenizer: Any, + num_samples: int, + max_seq_len: int, + seed: int = 42, +) -> List[str]: + """Load training text samples from a dataset. + + Tries to use OpenHermes-2.5, falls back to a simple math/code/text mix. + """ + try: + from datasets import load_dataset + logger.info("Loading training data from teknium/OpenHermes-2.5...") + ds = load_dataset("teknium/OpenHermes-2.5", split="train", streaming=True) + texts = [] + for item in ds: + if len(texts) >= num_samples: + break + # Extract conversation text + convos = item.get("conversations", []) + text = " ".join(c.get("value", "") for c in convos) + if len(text) > 50: # Skip very short entries + texts.append(text[:max_seq_len * 4]) # Rough char limit + logger.info(f"Loaded {len(texts)} training samples") + return texts + except Exception as e: + logger.warning(f"Could not load OpenHermes: {e}. Using fallback data.") + # Fallback: generate simple diverse prompts + import random + rng = random.Random(seed) + templates = [ + "Solve the following math problem step by step: {a} + {b} * {c} = ?", + "Write a function that computes the {n}th Fibonacci number.", + "Explain the concept of {topic} in simple terms.", + "The quick brown fox jumps over the lazy dog. {extra}", + "In {year}, {person} made a significant discovery about {topic}.", + ] + topics = ["recursion", "neural networks", "gravity", "evolution", "democracy", + "photosynthesis", "quantum mechanics", "machine learning", "databases"] + texts = [] + for i in range(num_samples): + tmpl = rng.choice(templates) + text = tmpl.format( + a=rng.randint(1, 1000), b=rng.randint(1, 100), c=rng.randint(1, 50), + n=rng.randint(5, 50), + topic=rng.choice(topics), + extra=" ".join(rng.choices(topics, k=3)), + year=rng.randint(1900, 2025), + person=f"Dr. {rng.choice(['Smith', 'Chen', 'Patel', 'Garcia', 'Kim'])}", + ) + texts.append(text) + return texts + + +def _tokenize_batch( + texts: List[str], + tokenizer: Any, + max_seq_len: int, + device: str, +) -> Any: + """Tokenize a batch of texts.""" + torch = _require_torch() + encoded = tokenizer( + texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_seq_len, + add_special_tokens=True, + ) + return {k: v.to(device) for k, v in encoded.items()} + + +def train_projector( + source_model: Any, + target_model: Any, + source_tokenizer: Any, + target_tokenizer: Any, + device: str = "cuda", + config: Optional[TrainConfig] = None, + progress_callback: Optional[Any] = None, +) -> "AVPMap": + """Train a per-layer cross-model projector. + + Both models are frozen. Only the lightweight projector trains. + Uses MSE loss between projected source hidden states and target + reference hidden states at each layer. + + Args: + source_model: Source HuggingFace model (frozen). + target_model: Target HuggingFace model (frozen). + source_tokenizer: Source tokenizer. + target_tokenizer: Target tokenizer (used for training data). + device: Training device. + config: Training configuration. + progress_callback: Optional callback(step, total_steps, loss) for progress. + + Returns: + AVPMap with method=TRAINED and per-layer projection data. + """ + torch = _require_torch() + import torch.nn.functional as F + from ..handshake import compute_model_hash, extract_model_identity + + if config is None: + config = TrainConfig() + + # Set seed + torch.manual_seed(config.seed) + + # Get model info + src_identity = extract_model_identity(source_model, source_tokenizer) + tgt_identity = extract_model_identity(target_model, target_tokenizer) + src_hash = compute_model_hash(src_identity) + tgt_hash = compute_model_hash(tgt_identity) + + source_dim = src_identity.hidden_dim + target_dim = tgt_identity.hidden_dim + target_num_layers = tgt_identity.num_layers + + logger.info( + "Training projector: %s (%dd) -> %s (%dd, %d layers)", + src_identity.model_id, source_dim, + tgt_identity.model_id, target_dim, target_num_layers, + ) + + # Create projector + projector = LayerProjector(source_dim, target_dim, target_num_layers, config) + projector.to(device) + projector.train() + + # Freeze both models + source_model.eval() + target_model.eval() + for p in source_model.parameters(): + p.requires_grad_(False) + for p in target_model.parameters(): + p.requires_grad_(False) + + # Load training data + texts = _load_training_data( + target_tokenizer, config.num_samples, config.max_seq_len, config.seed + ) + + # Optimizer + optimizer = torch.optim.AdamW( + projector.parameters(), + lr=config.learning_rate, + weight_decay=0.01, + ) + + # Training loop + num_batches = math.ceil(len(texts) / config.batch_size) + total_steps = num_batches * config.num_epochs + step = 0 + best_loss = float("inf") + + for epoch in range(config.num_epochs): + epoch_loss = 0.0 + epoch_steps = 0 + + for batch_start in range(0, len(texts), config.batch_size): + batch_texts = texts[batch_start:batch_start + config.batch_size] + + # Tokenize with target tokenizer (both models process same text) + try: + tgt_encoded = _tokenize_batch(batch_texts, target_tokenizer, config.max_seq_len, device) + src_encoded = _tokenize_batch(batch_texts, source_tokenizer, config.max_seq_len, device) + except Exception: + continue + + # Source forward pass -> extract hidden state + with torch.no_grad(): + src_out = source_model( + **src_encoded, + output_hidden_states=True, + return_dict=True, + ) + # Extract from extraction layer ratio + src_hidden_states = src_out.hidden_states + extract_idx = int(len(src_hidden_states) * config.extraction_layer_ratio) + extract_idx = min(extract_idx, len(src_hidden_states) - 1) + src_hidden = src_hidden_states[extract_idx][:, -1, :] # [B, D_src] + + # Target forward pass -> reference hidden states at each layer + with torch.no_grad(): + tgt_out = target_model( + **tgt_encoded, + output_hidden_states=True, + return_dict=True, + ) + tgt_hidden_states = tgt_out.hidden_states # tuple of [B, seq, D_tgt] + + # Project and compute loss + projections = projector.forward(src_hidden.float()) + + loss = torch.tensor(0.0, device=device, requires_grad=True) + for layer_idx, (proj_h, gate) in enumerate(projections): + if gate < 0.001: + continue # Skip nearly-zero gates for efficiency + + # Reference: target hidden state at this layer's last token + # tgt_hidden_states has (num_layers + 1) entries + ref_idx = min(layer_idx + 1, len(tgt_hidden_states) - 1) + ref_h = tgt_hidden_states[ref_idx][:, -1, :].float() # [B, D_tgt] + + # MSE loss weighted by gate + layer_loss = F.mse_loss(proj_h, ref_h) + loss = loss + gate * layer_loss + + # Gate sparsity regularization (encourage most gates to be ~0) + gate_sum = sum( + torch.sigmoid(g) for g in projector.layer_gates + ) / target_num_layers + loss = loss + config.gate_reg_weight * gate_sum + + # Backward pass + optimizer.zero_grad() + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(projector.parameters(), 1.0) + + optimizer.step() + + step += 1 + loss_val = loss.item() + epoch_loss += loss_val + epoch_steps += 1 + + if step % 50 == 0: + active = projector.get_active_layers() + logger.info( + "Step %d/%d, loss=%.4f, active_layers=%d/%d", + step, total_steps, loss_val, len(active), target_num_layers, + ) + + if progress_callback is not None: + progress_callback(step, total_steps, loss_val) + + # Free intermediate tensors + del src_out, tgt_out, projections, loss + if device == "cuda": + torch.cuda.empty_cache() + + avg_loss = epoch_loss / max(epoch_steps, 1) + logger.info("Epoch %d/%d complete, avg_loss=%.4f", epoch + 1, config.num_epochs, avg_loss) + + if avg_loss < best_loss: + best_loss = avg_loss + + # Export to AVPMap + projector.eval() + layer_weights, layer_biases, layer_gates = projector.export_weights() + + active = projector.get_active_layers() + logger.info( + "Training complete. Active layers: %d/%d (gates > 0.01): %s", + len(active), target_num_layers, active, + ) + + from .calibrate import AVPMap + from ..types import ProjectionMethod + + # Compute target norm from target model embeddings + tgt_embed = target_model.get_input_embeddings() + target_norm = tgt_embed.weight.float().norm(dim=-1).mean().detach().cpu() + + avp_map = AVPMap( + source_model_id=src_identity.model_id, + source_hash=src_hash, + source_dim=source_dim, + target_model_id=tgt_identity.model_id, + target_hash=tgt_hash, + target_dim=target_dim, + w_map=torch.zeros(1), # Placeholder (projections stored in layer_weights) + bias=None, + target_norm=target_norm, + method=ProjectionMethod.TRAINED, + anchor_count=config.num_samples, + validation_score=1.0 - best_loss, # Higher is better + layer_weights=layer_weights, + layer_biases=layer_biases, + layer_gates=layer_gates, + ) + + return avp_map diff --git a/src/avp/rosetta/trained_hooks.py b/src/avp/rosetta/trained_hooks.py new file mode 100644 index 0000000..a131bfa --- /dev/null +++ b/src/avp/rosetta/trained_hooks.py @@ -0,0 +1,106 @@ +"""Forward hooks for trained per-layer cross-model projection. + +Installs per-layer forward hooks that additively inject projected source +hidden states (scaled by learned gates) during the first forward pass (prefill). +Hooks fire once and are removed after the context manager exits. +""" + +import logging +from contextlib import contextmanager +from typing import Any, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _get_decoder_layers(model: Any): + """Get the list of decoder layers from a HuggingFace model.""" + inner = getattr(model, "model", None) + if inner is not None: + layers = getattr(inner, "layers", None) + if layers is not None: + return layers + + transformer = getattr(model, "transformer", None) + if transformer is not None: + h = getattr(transformer, "h", None) + if h is not None: + return h + + raise AttributeError( + f"Cannot find decoder layers in model {type(model).__name__}. " + "Expected model.model.layers or model.transformer.h" + ) + + +@contextmanager +def trained_multi_layer_hook( + model: Any, + layer_projections: List[Optional[Tuple[Any, float]]], +): + """Context manager that installs per-layer forward hooks for trained projection. + + Each hook adds the pre-computed projected hidden state (scaled by gate) + to the target layer's last-token hidden state. Hooks fire only on the + first forward pass (prefill), not during autoregressive generation. + + Args: + model: HuggingFace model to hook into. + layer_projections: List of length num_layers. Each entry is either: + - None (gate < threshold, skip this layer) + - (projected_hidden [1, D_tgt], gate_value float) + + Yields: + None. Hooks are active during the context. + """ + import torch + + layers = _get_decoder_layers(model) + handles = [] + fired_flags = [] + + for layer_idx, proj_data in enumerate(layer_projections): + if proj_data is None: + continue + if layer_idx >= len(layers): + break + + projected, gate = proj_data + target_layer = layers[layer_idx] + fired = [False] + fired_flags.append(fired) + + def make_hook(proj_h, g, f): + def hook_fn(module, input, output): + if f[0]: + return output + + f[0] = True + + if isinstance(output, tuple): + hidden = output[0] # [B, seq_len, D] + else: + hidden = output + + # Add projected source hidden state (scaled by gate) to last token + injection = proj_h.to(device=hidden.device, dtype=hidden.dtype) + if injection.dim() == 2: + injection = injection.squeeze(0) # [1, D] -> [D] + + modified = hidden.clone() + modified[:, -1, :] = modified[:, -1, :] + g * injection + + if isinstance(output, tuple): + return (modified,) + output[1:] + return modified + + return hook_fn + + hook = make_hook(projected, gate, fired) + handle = target_layer.register_forward_hook(hook) + handles.append(handle) + + try: + yield + finally: + for handle in handles: + handle.remove() diff --git a/src/avp/types.py b/src/avp/types.py index c156253..3ac747d 100644 --- a/src/avp/types.py +++ b/src/avp/types.py @@ -60,6 +60,7 @@ class ProjectionMethod(enum.Enum): PROCRUSTES = "procrustes" VOCAB_MEDIATED = "vocab_mediated" VOCAB_OVERLAP = "vocab_overlap" + TRAINED = "trained" class DataType(enum.IntEnum): diff --git a/tests/test_trained_projector.py b/tests/test_trained_projector.py new file mode 100644 index 0000000..64dc4ea --- /dev/null +++ b/tests/test_trained_projector.py @@ -0,0 +1,339 @@ +"""Tests for the trained cross-model projector (C2C). + +Tests LayerProjector, TrainConfig, trained_hooks, and registry serialization +of trained AVPMap fields. +""" + +import pytest +import torch +import torch.nn as nn + +from avp.rosetta.train import LayerProjector, TrainConfig +from avp.rosetta.trained_hooks import trained_multi_layer_hook, _get_decoder_layers +from avp.rosetta.calibrate import AVPMap +from avp.types import ProjectionMethod + + +# --- TrainConfig tests --- + + +class TestTrainConfig: + def test_defaults(self): + config = TrainConfig() + assert config.num_samples == 5000 + assert config.batch_size == 4 + assert config.learning_rate == 1e-4 + assert config.num_epochs == 2 + assert config.gate_reg_weight == 0.01 + assert config.gate_init == -3.0 + assert config.max_seq_len == 256 + assert config.warmup_steps == 100 + assert config.seed == 42 + + def test_custom_values(self): + config = TrainConfig(num_samples=100, batch_size=8, learning_rate=1e-3) + assert config.num_samples == 100 + assert config.batch_size == 8 + assert config.learning_rate == 1e-3 + + +# --- LayerProjector tests --- + + +class TestLayerProjector: + def test_init_shapes(self): + proj = LayerProjector(source_dim=128, target_dim=64, num_layers=4) + assert len(proj.layer_projections) == 4 + assert len(proj.layer_gates) == 4 + # Each projection: [D_src] -> [D_tgt] + assert proj.layer_projections[0].in_features == 128 + assert proj.layer_projections[0].out_features == 64 + + def test_gate_init_near_zero(self): + """Gates should be initialized near zero (sigmoid(-3) ~ 0.05).""" + proj = LayerProjector(source_dim=32, target_dim=32, num_layers=8) + for gate_logit in proj.layer_gates: + gate_val = torch.sigmoid(gate_logit).item() + assert gate_val < 0.1, f"Gate initialized too high: {gate_val}" + + def test_forward_output_shapes(self): + proj = LayerProjector(source_dim=64, target_dim=32, num_layers=3) + src = torch.randn(2, 64) # batch=2 + results = proj.forward(src) + assert len(results) == 3 + for projected, gate in results: + assert projected.shape == (2, 32) + assert 0 <= gate <= 1 + + def test_get_active_layers_initial(self): + """Initially, all gates near zero → no active layers at default threshold.""" + proj = LayerProjector(source_dim=32, target_dim=32, num_layers=8) + active = proj.get_active_layers(threshold=0.1) + assert len(active) == 0 + + def test_get_active_layers_after_setting_gate(self): + proj = LayerProjector(source_dim=32, target_dim=32, num_layers=4) + # Force gate 1 and 3 open + with torch.no_grad(): + proj.layer_gates[1].fill_(5.0) # sigmoid(5) ~ 0.993 + proj.layer_gates[3].fill_(3.0) # sigmoid(3) ~ 0.953 + active = proj.get_active_layers(threshold=0.5) + assert active == [1, 3] + + def test_export_weights(self): + proj = LayerProjector(source_dim=64, target_dim=32, num_layers=3) + weights, biases, gates = proj.export_weights() + assert len(weights) == 3 + assert len(biases) == 3 + assert len(gates) == 3 + for w in weights: + assert w.shape == (32, 64) # nn.Linear stores [out, in] + for b in biases: + assert b.shape == (32,) + for g in gates: + assert isinstance(g, float) + assert 0 <= g <= 1 + + def test_parameters_count(self): + proj = LayerProjector(source_dim=64, target_dim=32, num_layers=3) + params = list(proj.parameters()) + # 3 layers × (weight + bias) + 3 gates = 9 parameters + assert len(params) == 9 + + def test_to_device(self): + proj = LayerProjector(source_dim=32, target_dim=16, num_layers=2) + proj.to("cpu") # should work without error + # Verify parameters are on cpu + for p in proj.parameters(): + assert str(p.device) == "cpu" + + +# --- Trained hooks tests --- + + +class FakeDecoderLayer(nn.Module): + """Minimal decoder layer for hook testing.""" + def __init__(self, dim): + super().__init__() + self.linear = nn.Linear(dim, dim) + + def forward(self, x): + return (self.linear(x),) + + +class FakeModel(nn.Module): + """Minimal model with decoder layers for hook testing.""" + def __init__(self, dim, num_layers): + super().__init__() + self.model = nn.Module() + self.model.layers = nn.ModuleList([ + FakeDecoderLayer(dim) for _ in range(num_layers) + ]) + + +class TestTrainedMultiLayerHook: + def test_hook_adds_projection(self): + dim = 16 + model = FakeModel(dim=dim, num_layers=4) + model.eval() + + projected = torch.ones(1, dim) * 10.0 + gate = 0.5 + + layer_projections = [None, (projected, gate), None, None] + + x = torch.randn(1, 5, dim) + + with trained_multi_layer_hook(model, layer_projections): + # Forward through layer 1 (the hooked layer) + layers = _get_decoder_layers(model) + output = layers[1](x) + + # The hook should have modified the last token + modified_hidden = output[0] + # Verify last token was modified (added gate * projection) + # Original output + 0.5 * 10.0 = original + 5.0 + with torch.no_grad(): + original_output = model.model.layers[1](x) + original_last = original_output[0][:, -1, :] + modified_last = modified_hidden[:, -1, :] + diff = (modified_last - original_last).abs().mean().item() + assert diff > 1.0, f"Hook didn't modify output enough: diff={diff}" + + def test_hook_fires_once(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=2) + model.eval() + + projected = torch.ones(1, dim) * 100.0 + layer_projections = [(projected, 1.0), None] + + x = torch.randn(1, 3, dim) + + with trained_multi_layer_hook(model, layer_projections): + layers = _get_decoder_layers(model) + # First forward — hook fires + out1 = layers[0](x) + # Second forward — hook should NOT fire + out2 = layers[0](x) + + # out1 should differ from out2 (hook only fires once) + # Actually both outputs are from the same input x, but the first + # has the injection and the second doesn't + last1 = out1[0][:, -1, :].detach() + last2 = out2[0][:, -1, :].detach() + diff = (last1 - last2).abs().mean().item() + assert diff > 1.0, f"Hook fired more than once or didn't fire: diff={diff}" + + def test_hook_cleanup(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=2) + + projected = torch.ones(1, dim) + layer_projections = [(projected, 0.5), (projected, 0.3)] + + with trained_multi_layer_hook(model, layer_projections): + pass + + # Verify hooks are removed + layers = _get_decoder_layers(model) + for layer in layers: + assert len(layer._forward_hooks) == 0 + + def test_skip_none_layers(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=4) + + # Only layer 2 active + layer_projections = [None, None, (torch.ones(1, dim), 0.5), None] + + with trained_multi_layer_hook(model, layer_projections): + layers = _get_decoder_layers(model) + # Layers 0, 1, 3 should have no hooks + assert len(layers[0]._forward_hooks) == 0 + assert len(layers[1]._forward_hooks) == 0 + assert len(layers[2]._forward_hooks) == 1 + assert len(layers[3]._forward_hooks) == 0 + + def test_exception_cleanup(self): + dim = 8 + model = FakeModel(dim=dim, num_layers=2) + + projected = torch.ones(1, dim) + layer_projections = [(projected, 0.5), (projected, 0.3)] + + with pytest.raises(RuntimeError): + with trained_multi_layer_hook(model, layer_projections): + raise RuntimeError("test error") + + # Hooks should still be cleaned up + layers = _get_decoder_layers(model) + for layer in layers: + assert len(layer._forward_hooks) == 0 + + +# --- Registry serialization tests --- + + +class TestRegistrySerialization: + def test_save_load_trained_map(self, tmp_path): + """Round-trip save/load of AVPMap with trained projection fields.""" + from avp.rosetta.registry import save_map, load_map + + num_layers = 3 + src_dim = 64 + tgt_dim = 32 + + avp_map = AVPMap( + source_model_id="test/source", + source_hash="abc123" * 8, + source_dim=src_dim, + target_model_id="test/target", + target_hash="def456" * 8, + target_dim=tgt_dim, + w_map=torch.zeros(1), + bias=None, + target_norm=torch.tensor(1.0), + method=ProjectionMethod.TRAINED, + anchor_count=5000, + validation_score=0.85, + layer_weights=[torch.randn(tgt_dim, src_dim) for _ in range(num_layers)], + layer_biases=[torch.randn(tgt_dim) for _ in range(num_layers)], + layer_gates=[0.01, 0.95, 0.03], + ) + + path = save_map(avp_map, map_dir=tmp_path) + assert path.exists() + + loaded = load_map( + avp_map.source_hash, avp_map.target_hash, + device="cpu", map_dir=tmp_path, + ) + assert loaded is not None + assert loaded.method == ProjectionMethod.TRAINED + assert len(loaded.layer_weights) == num_layers + assert len(loaded.layer_biases) == num_layers + assert loaded.layer_gates == [0.01, 0.95, 0.03] + + # Verify tensor shapes preserved + for w in loaded.layer_weights: + assert w.shape == (tgt_dim, src_dim) + for b in loaded.layer_biases: + assert b.shape == (tgt_dim,) + + def test_save_load_non_trained_map_no_layer_fields(self, tmp_path): + """Non-trained maps should have None for layer fields.""" + from avp.rosetta.registry import save_map, load_map + + avp_map = AVPMap( + source_model_id="test/source", + source_hash="aaa111" * 8, + source_dim=64, + target_model_id="test/target", + target_hash="bbb222" * 8, + target_dim=32, + w_map=torch.randn(64, 32), + bias=None, + target_norm=torch.tensor(1.0), + method=ProjectionMethod.RIDGE, + anchor_count=50, + validation_score=0.5, + ) + + save_map(avp_map, map_dir=tmp_path) + loaded = load_map( + avp_map.source_hash, avp_map.target_hash, + device="cpu", map_dir=tmp_path, + ) + assert loaded is not None + assert loaded.layer_weights is None + assert loaded.layer_biases is None + assert loaded.layer_gates is None + + +# --- AVPMap TRAINED method test --- + + +class TestAVPMapTrained: + def test_trained_method_enum(self): + assert ProjectionMethod.TRAINED.value == "trained" + + def test_avpmap_with_trained_fields(self): + avp_map = AVPMap( + source_model_id="src", + source_hash="h1", + source_dim=128, + target_model_id="tgt", + target_hash="h2", + target_dim=64, + w_map=torch.zeros(1), + bias=None, + target_norm=torch.tensor(1.0), + method="trained", # string should convert + anchor_count=1000, + validation_score=0.9, + layer_weights=[torch.randn(64, 128)], + layer_biases=[torch.randn(64)], + layer_gates=[0.5], + ) + assert avp_map.method == ProjectionMethod.TRAINED From 4c6626af906d75061b543559bdc2d752991db9f7 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 02:04:24 +0000 Subject: [PATCH 04/14] Fix gradient flow through gate logits in training loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The forward() method was calling .item() on sigmoid gates, detaching them from the computation graph. This meant gate logits only received gradients from L1 regularization, not from the MSE loss — so gates couldn't learn which layers are important from the training signal. Fix: add return_gate_tensors parameter. Training uses True (tensor gates for gradient flow), inference uses False (float gates for speed). Co-Authored-By: Claude Opus 4.6 --- src/avp/rosetta/train.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/avp/rosetta/train.py b/src/avp/rosetta/train.py index 66a637d..569374b 100644 --- a/src/avp/rosetta/train.py +++ b/src/avp/rosetta/train.py @@ -129,21 +129,27 @@ def eval(self): for proj in self.layer_projections: proj.eval() - def forward(self, source_hidden): + def forward(self, source_hidden, return_gate_tensors: bool = False): """Project source hidden state to each target layer. Args: source_hidden: [B, D_src] from source model. + return_gate_tensors: If True, return gate as tensor (for training + gradient flow). If False, return gate as float (for inference). Returns: - List of (projected_hidden [B, D_tgt], gate_value float) per layer. + List of (projected_hidden [B, D_tgt], gate_value) per layer. + gate_value is a float (inference) or tensor (training). """ torch = _require_torch() results = [] for proj, gate_logit in zip(self.layer_projections, self.layer_gates): projected = proj(source_hidden) # [B, D_tgt] - gate = torch.sigmoid(gate_logit).item() # scalar in [0, 1] - results.append((projected, gate)) + gate_tensor = torch.sigmoid(gate_logit) # scalar tensor in [0, 1] + if return_gate_tensors: + results.append((projected, gate_tensor)) + else: + results.append((projected, gate_tensor.item())) return results def get_active_layers(self, threshold: float = 0.01) -> List[int]: @@ -367,12 +373,13 @@ def train_projector( ) tgt_hidden_states = tgt_out.hidden_states # tuple of [B, seq, D_tgt] - # Project and compute loss - projections = projector.forward(src_hidden.float()) + # Project and compute loss (return_gate_tensors=True for gradient flow) + projections = projector.forward(src_hidden.float(), return_gate_tensors=True) loss = torch.tensor(0.0, device=device, requires_grad=True) for layer_idx, (proj_h, gate) in enumerate(projections): - if gate < 0.001: + gate_val = gate.item() + if gate_val < 0.001: continue # Skip nearly-zero gates for efficiency # Reference: target hidden state at this layer's last token @@ -380,7 +387,7 @@ def train_projector( ref_idx = min(layer_idx + 1, len(tgt_hidden_states) - 1) ref_h = tgt_hidden_states[ref_idx][:, -1, :].float() # [B, D_tgt] - # MSE loss weighted by gate + # MSE loss weighted by gate (gate is tensor for gradient flow) layer_loss = F.mse_loss(proj_h, ref_h) loss = loss + gate * layer_loss From afbb0089430f775f9631c3a0e5d096ac3a7c35e8 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 02:05:32 +0000 Subject: [PATCH 05/14] Fix compute_model_hash call in train_projector Was passing ModelIdentity object instead of config dict. Now uses model.config.to_dict() like all other call sites. Co-Authored-By: Claude Opus 4.6 --- src/avp/rosetta/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/avp/rosetta/train.py b/src/avp/rosetta/train.py index 569374b..ab9b35d 100644 --- a/src/avp/rosetta/train.py +++ b/src/avp/rosetta/train.py @@ -293,8 +293,8 @@ def train_projector( # Get model info src_identity = extract_model_identity(source_model, source_tokenizer) tgt_identity = extract_model_identity(target_model, target_tokenizer) - src_hash = compute_model_hash(src_identity) - tgt_hash = compute_model_hash(tgt_identity) + src_hash = compute_model_hash(source_model.config.to_dict()) + tgt_hash = compute_model_hash(target_model.config.to_dict()) source_dim = src_identity.hidden_dim target_dim = tgt_identity.hidden_dim From 37a588d090768d0c54b6b052e212863e4d083742 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 03:02:48 +0000 Subject: [PATCH 06/14] Add NTP loss as primary training objective for C2C projector Research shows MSE-only training optimizes geometric alignment but not downstream generation quality. This adds cross-entropy (NTP) loss through the hooked target model as the primary loss, with MSE as auxiliary (0.1 weight). Also fixes MSE to use unhooked reference hidden states (avoiding circular reference), lowers gate_init to -5.0 for less initial corruption. Co-Authored-By: Claude Opus 4.6 --- src/avp/rosetta/train.py | 89 +++++++++++++++++++++++++-------- tests/test_trained_projector.py | 4 +- 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/src/avp/rosetta/train.py b/src/avp/rosetta/train.py index ab9b35d..92fc562 100644 --- a/src/avp/rosetta/train.py +++ b/src/avp/rosetta/train.py @@ -62,11 +62,13 @@ class TrainConfig: learning_rate: float = 1e-4 num_epochs: int = 2 gate_reg_weight: float = 0.01 - gate_init: float = -3.0 # sigmoid(-3) ~ 0.05 + gate_init: float = -5.0 # sigmoid(-5) ~ 0.007 — near zero to avoid corrupting target max_seq_len: int = 256 extraction_layer_ratio: float = 0.75 warmup_steps: int = 100 seed: int = 42 + mse_aux_weight: float = 0.1 # Weight for MSE auxiliary loss (0 = NTP only) + use_ntp_loss: bool = True # Use next-token prediction as primary loss class LayerProjector: @@ -364,32 +366,74 @@ def train_projector( extract_idx = min(extract_idx, len(src_hidden_states) - 1) src_hidden = src_hidden_states[extract_idx][:, -1, :] # [B, D_src] - # Target forward pass -> reference hidden states at each layer - with torch.no_grad(): - tgt_out = target_model( - **tgt_encoded, - output_hidden_states=True, - return_dict=True, - ) - tgt_hidden_states = tgt_out.hidden_states # tuple of [B, seq, D_tgt] - - # Project and compute loss (return_gate_tensors=True for gradient flow) + # Project source hidden state through all layers projections = projector.forward(src_hidden.float(), return_gate_tensors=True) - loss = torch.tensor(0.0, device=device, requires_grad=True) - for layer_idx, (proj_h, gate) in enumerate(projections): + # Build per-layer projections for hooks + layer_proj_for_hooks = [] + for proj_h, gate in projections: gate_val = gate.item() if gate_val < 0.001: - continue # Skip nearly-zero gates for efficiency + layer_proj_for_hooks.append(None) + else: + layer_proj_for_hooks.append((proj_h, gate)) - # Reference: target hidden state at this layer's last token - # tgt_hidden_states has (num_layers + 1) entries - ref_idx = min(layer_idx + 1, len(tgt_hidden_states) - 1) - ref_h = tgt_hidden_states[ref_idx][:, -1, :].float() # [B, D_tgt] + loss = torch.tensor(0.0, device=device, requires_grad=True) - # MSE loss weighted by gate (gate is tensor for gradient flow) - layer_loss = F.mse_loss(proj_h, ref_h) - loss = loss + gate * layer_loss + # Reference forward pass (unhooked, no grad) for MSE auxiliary + ref_hidden_states = None + if config.mse_aux_weight > 0: + with torch.no_grad(): + ref_out = target_model( + **tgt_encoded, + output_hidden_states=True, + return_dict=True, + ) + ref_hidden_states = ref_out.hidden_states + + # Primary loss: NTP through target model with injected projections + if config.use_ntp_loss: + from .trained_hooks import trained_multi_layer_hook + tgt_input_ids = tgt_encoded["input_ids"] + labels = tgt_input_ids[:, 1:].contiguous() # shift right + + with trained_multi_layer_hook(target_model, layer_proj_for_hooks): + tgt_out = target_model( + **tgt_encoded, + output_hidden_states=True, + return_dict=True, + ) + logits = tgt_out.logits[:, :-1, :].contiguous() # [B, seq-1, vocab] + ntp_loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + ignore_index=target_tokenizer.pad_token_id or -100, + ) + loss = loss + ntp_loss + elif config.mse_aux_weight > 0: + pass # ref_hidden_states already computed above + else: + # Neither NTP nor MSE — nothing to train on + with torch.no_grad(): + ref_out = target_model( + **tgt_encoded, + output_hidden_states=True, + return_dict=True, + ) + ref_hidden_states = ref_out.hidden_states + + # Auxiliary loss: MSE between projected and unhooked reference hidden states + if config.mse_aux_weight > 0 and ref_hidden_states is not None: + mse_loss = torch.tensor(0.0, device=device, requires_grad=True) + for layer_idx, (proj_h, gate) in enumerate(projections): + gate_val = gate.item() + if gate_val < 0.001: + continue + ref_idx = min(layer_idx + 1, len(ref_hidden_states) - 1) + ref_h = ref_hidden_states[ref_idx][:, -1, :].float().detach() + layer_loss = F.mse_loss(proj_h, ref_h) + mse_loss = mse_loss + gate * layer_loss + loss = loss + config.mse_aux_weight * mse_loss # Gate sparsity regularization (encourage most gates to be ~0) gate_sum = sum( @@ -422,7 +466,8 @@ def train_projector( progress_callback(step, total_steps, loss_val) # Free intermediate tensors - del src_out, tgt_out, projections, loss + del src_out, projections, loss + ref_hidden_states = None if device == "cuda": torch.cuda.empty_cache() diff --git a/tests/test_trained_projector.py b/tests/test_trained_projector.py index 64dc4ea..3a69063 100644 --- a/tests/test_trained_projector.py +++ b/tests/test_trained_projector.py @@ -25,10 +25,12 @@ def test_defaults(self): assert config.learning_rate == 1e-4 assert config.num_epochs == 2 assert config.gate_reg_weight == 0.01 - assert config.gate_init == -3.0 + assert config.gate_init == -5.0 assert config.max_seq_len == 256 assert config.warmup_steps == 100 assert config.seed == 42 + assert config.mse_aux_weight == 0.1 + assert config.use_ntp_loss is True def test_custom_values(self): config = TrainConfig(num_samples=100, batch_size=8, learning_rate=1e-3) From eb5dbd22f85d973a5b7527ff458837f1589ba7ec Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 03:04:31 +0000 Subject: [PATCH 07/14] Skip hidden states storage in hooked NTP forward pass The hooked forward pass only needs logits for NTP loss. Hidden states for MSE auxiliary come from the separate unhooked reference pass. Setting output_hidden_states=False saves activation memory. Co-Authored-By: Claude Opus 4.6 --- src/avp/rosetta/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/avp/rosetta/train.py b/src/avp/rosetta/train.py index 92fc562..6bfdcc0 100644 --- a/src/avp/rosetta/train.py +++ b/src/avp/rosetta/train.py @@ -400,7 +400,7 @@ def train_projector( with trained_multi_layer_hook(target_model, layer_proj_for_hooks): tgt_out = target_model( **tgt_encoded, - output_hidden_states=True, + output_hidden_states=False, return_dict=True, ) logits = tgt_out.logits[:, :-1, :].contiguous() # [B, seq-1, vocab] From 1f4c4a137dfdd97e38d7c6745391b03cb584e1fb Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 04:52:03 +0000 Subject: [PATCH 08/14] Add gate_init and gate_reg_weight passthrough to trained benchmark config Exp 3 (NTP loss) failed due to cold-start gate collapse: gate_init=-5.0 combined with L1 regularization pushed all 28 gates to zero. Now exposing these hyperparameters so experiments can test warm-gate NTP configurations. Co-Authored-By: Claude Opus 4.6 --- benchmarks/gsm8k_2agent/run_gsm8k_2agent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py index 1d4b120..93d7352 100644 --- a/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py +++ b/benchmarks/gsm8k_2agent/run_gsm8k_2agent.py @@ -333,6 +333,8 @@ def run_benchmark(config: dict) -> dict: batch_size=config.get("train_batch_size", 4), num_epochs=config.get("train_epochs", 2), learning_rate=config.get("train_lr", 1e-4), + gate_init=config.get("train_gate_init", -5.0), + gate_reg_weight=config.get("train_gate_reg", 0.01), ) print(f"Training projector: {train_config.num_samples} samples, " f"{train_config.num_epochs} epochs...") From 867eb21ea7e92d4e32b676272c645e2a34ca9324 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 06:44:12 +0000 Subject: [PATCH 09/14] Set MSE-only as default training config (gate_init=-3.0, use_ntp_loss=False) 4 experiments showed MSE-only with gate_init=-3.0 matches NTP loss at 76% GSM8K cross-family accuracy (+6pp over rosetta) while requiring half the training compute. NTP with cold gates (-5.0) causes gate collapse to 0/28. Co-Authored-By: Claude Opus 4.6 --- src/avp/rosetta/train.py | 4 ++-- tests/test_trained_projector.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/avp/rosetta/train.py b/src/avp/rosetta/train.py index 6bfdcc0..7cf3c2b 100644 --- a/src/avp/rosetta/train.py +++ b/src/avp/rosetta/train.py @@ -62,13 +62,13 @@ class TrainConfig: learning_rate: float = 1e-4 num_epochs: int = 2 gate_reg_weight: float = 0.01 - gate_init: float = -5.0 # sigmoid(-5) ~ 0.007 — near zero to avoid corrupting target + gate_init: float = -3.0 # sigmoid(-3) ~ 0.05 — warm enough for gradient signal max_seq_len: int = 256 extraction_layer_ratio: float = 0.75 warmup_steps: int = 100 seed: int = 42 mse_aux_weight: float = 0.1 # Weight for MSE auxiliary loss (0 = NTP only) - use_ntp_loss: bool = True # Use next-token prediction as primary loss + use_ntp_loss: bool = False # MSE-only is sufficient (NTP adds no accuracy, doubles compute) class LayerProjector: diff --git a/tests/test_trained_projector.py b/tests/test_trained_projector.py index 3a69063..5f495d5 100644 --- a/tests/test_trained_projector.py +++ b/tests/test_trained_projector.py @@ -25,12 +25,12 @@ def test_defaults(self): assert config.learning_rate == 1e-4 assert config.num_epochs == 2 assert config.gate_reg_weight == 0.01 - assert config.gate_init == -5.0 + assert config.gate_init == -3.0 assert config.max_seq_len == 256 assert config.warmup_steps == 100 assert config.seed == 42 assert config.mse_aux_weight == 0.1 - assert config.use_ntp_loss is True + assert config.use_ntp_loss is False def test_custom_values(self): config = TrainConfig(num_samples=100, batch_size=8, learning_rate=1e-3) From faa1aa9c6caa943da6c971b61d5e6e03bbdddfb4 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 19:59:53 +0000 Subject: [PATCH 10/14] Add hybrid latent + selective text mode to HotpotQA rosetta pipeline Extract top-K tokens by attention weight from source model's forward pass, decode to text, re-tokenize on target side, prepend as embeddings before the projected latent vector. Controlled by hybrid_k parameter (0=disabled). Co-Authored-By: Claude Opus 4.6 --- benchmarks/hotpotqa/pipeline_rosetta.py | 74 ++++++++++++++++++++++++- benchmarks/hotpotqa/run_hotpotqa.py | 2 + 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/benchmarks/hotpotqa/pipeline_rosetta.py b/benchmarks/hotpotqa/pipeline_rosetta.py index da4c3f3..526210c 100644 --- a/benchmarks/hotpotqa/pipeline_rosetta.py +++ b/benchmarks/hotpotqa/pipeline_rosetta.py @@ -18,6 +18,43 @@ from .evaluate import exact_match, extract_answer, token_f1 +def extract_key_tokens( + attention_weights: Any, + input_ids: Any, + tokenizer: Any, + prompt_tokens: int, + k: int = 64, +) -> str: + """Extract top-K important tokens from attention weights. + + Uses attention from the final token (over the full KV-cache) to score + input token importance. Returns decoded text of the top-K tokens. + + Args: + attention_weights: Last layer attention, shape [batch, heads, 1, seq_len] + input_ids: Original input token IDs, shape [1, prompt_tokens] + tokenizer: Source model tokenizer + prompt_tokens: Number of original prompt tokens (before latent steps) + k: Number of tokens to extract + """ + # Average across heads: [seq_len] + attn_scores = attention_weights[0, :, -1, :].mean(dim=0) + + # Only score the original prompt tokens (not latent step positions) + prompt_scores = attn_scores[:prompt_tokens] + + # Select top-K by attention (capped at available tokens) + k = min(k, prompt_tokens) + topk_indices = prompt_scores.topk(k).indices + # Sort by position to maintain reading order + topk_indices = topk_indices.sort().values + + # Extract and decode + key_ids = input_ids[0, topk_indices] + key_text = tokenizer.decode(key_ids, skip_special_tokens=True) + return key_text + + def run_rosetta_pipeline( conn_a: Any, model_a: Any, @@ -37,6 +74,7 @@ def run_rosetta_pipeline( verbose: bool = False, projection_temperature: float = 1.0, num_transfer_states: int = 1, + hybrid_k: int = 0, ) -> Dict: """Run the 2-agent cross-model pipeline on a single HotpotQA problem. @@ -66,6 +104,8 @@ def run_rosetta_pipeline( total_latent_steps += latent_steps attention_entropy = None + key_text = None + hybrid_text_tokens = 0 if num_transfer_states > 1: # Multi-embedding: collect hidden states from all latent steps past_kv, hidden_states = conn_a.generate_latent_steps( @@ -113,6 +153,16 @@ def run_rosetta_pipeline( attn_ent = -(last_attn * attn_log).sum(dim=-1) # [batch, heads] attention_entropy = float(attn_ent.mean()) + # Extract key tokens for hybrid mode + key_text = None + if hybrid_k > 0 and out.attentions: + key_text = extract_key_tokens( + out.attentions[-1], input_ids, tokenizer_a, + prompt_tokens=prompt_tokens, k=hybrid_k, + ) + if verbose: + print(f" [Hybrid] Extracted {hybrid_k} key tokens: {key_text[:100]}...") + # Project to target model space projected, proj_metrics = conn_a.project_hidden_for_cross_model( last_hidden, avp_map, temperature=projection_temperature, @@ -150,6 +200,18 @@ def run_rosetta_pipeline( # --- Agent 2: Answerer on model B (inject projected embed, generate) --- inject_t0 = time.perf_counter() embed_input = rosetta_embeds.to(device).to(model_b.dtype) + + # Hybrid: prepend key text tokens as embeddings before latent + hybrid_text_tokens = 0 + if key_text: + key_ids_b = tokenizer_b.encode(key_text, add_special_tokens=False) + if key_ids_b: + key_ids_tensor = torch.tensor([key_ids_b], device=device) + key_embeds = model_b.get_input_embeddings()(key_ids_tensor) # [1, K, D_tgt] + key_embeds = key_embeds.to(model_b.dtype) + embed_input = torch.cat([key_embeds, embed_input], dim=1) # [1, K+1, D_tgt] + hybrid_text_tokens = len(key_ids_b) + embed_mask = torch.ones( (1, embed_input.shape[1]), dtype=torch.long, device=device, ) @@ -230,8 +292,11 @@ def run_rosetta_pipeline( "hidden_state_norm": float(proj_metrics["hidden_state_norm"].mean()) if "hidden_state_norm" in proj_metrics else None, "nearest_cos_sim": float(proj_metrics["nearest_cos_sim"].mean()) if "nearest_cos_sim" in proj_metrics else None, "attention_entropy": attention_entropy, + "hybrid_k": hybrid_k, + "hybrid_text_tokens": hybrid_text_tokens if hybrid_k > 0 else 0, + "hybrid_key_text": key_text if hybrid_k > 0 else None, "agents": agent_traces, - "mode": "rosetta", + "mode": "hybrid" if hybrid_k > 0 else "rosetta", } @@ -252,12 +317,14 @@ def run_rosetta_benchmark( verbose: bool = False, projection_temperature: float = 1.0, num_transfer_states: int = 1, + hybrid_k: int = 0, ) -> List[Dict]: """Run rosetta-mode pipeline on HotpotQA samples.""" results = [] + mode_label = f"Hybrid K={hybrid_k}" if hybrid_k > 0 else "Rosetta" for i, sample in enumerate(dataset): if verbose: - print(f"\n[Rosetta] Sample {i + 1}/{len(dataset)}: {sample['question'][:80]}...") + print(f"\n[{mode_label}] Sample {i + 1}/{len(dataset)}: {sample['question'][:80]}...") result = run_rosetta_pipeline( conn_a, model_a, tokenizer_a, identity_a, @@ -272,6 +339,7 @@ def run_rosetta_benchmark( verbose=verbose, projection_temperature=projection_temperature, num_transfer_states=num_transfer_states, + hybrid_k=hybrid_k, ) results.append(result) @@ -285,7 +353,7 @@ def run_rosetta_benchmark( correct = sum(1 for r in results if r["exact_match"]) f1s = [r["f1"] for r in results] mean_f1 = sum(f1s) / len(f1s) - print(f" [Rosetta] {i + 1}/{len(dataset)} " + print(f" [{mode_label}] {i + 1}/{len(dataset)} " f"(EM={correct}/{i + 1}, F1={mean_f1:.2f}, {result['wall_time']:.1f}s)", flush=True) diff --git a/benchmarks/hotpotqa/run_hotpotqa.py b/benchmarks/hotpotqa/run_hotpotqa.py index 622e8de..df51cd4 100644 --- a/benchmarks/hotpotqa/run_hotpotqa.py +++ b/benchmarks/hotpotqa/run_hotpotqa.py @@ -272,6 +272,7 @@ def run_benchmark(config: dict) -> dict: output_dir = config.get("output_dir") projection_temperature = config.get("projection_temperature", 1.0) num_transfer_states = config.get("num_transfer_states", 1) + hybrid_k = config.get("hybrid_k", 0) model_b_name = config.get("model_b", "") @@ -383,6 +384,7 @@ def run_benchmark(config: dict) -> dict: temperature=temperature, top_p=top_p, verbose=verbose, projection_temperature=projection_temperature, num_transfer_states=num_transfer_states, + hybrid_k=hybrid_k, ) del model_b, tokenizer_b, connector_b, identity_b From e71bfc177aee97c7a9ad0593797218abb0b86aab Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 20:09:00 +0000 Subject: [PATCH 11/14] Hybrid v2: inject key tokens as text in Agent B prompt, not inputs_embeds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit inputs_embeds injection had zero effect (model ignores raw embeddings). Now inject key tokens as "Key Context" in the answerer's prompt via input_ids — processed through normal embedding + positional encoding path. Co-Authored-By: Claude Opus 4.6 --- benchmarks/hotpotqa/agents.py | 33 ++++++++++++++++++------- benchmarks/hotpotqa/pipeline_rosetta.py | 21 +++++++--------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/benchmarks/hotpotqa/agents.py b/benchmarks/hotpotqa/agents.py index 4aa298e..7419392 100644 --- a/benchmarks/hotpotqa/agents.py +++ b/benchmarks/hotpotqa/agents.py @@ -63,10 +63,12 @@ def build_latent_prompt( role: str, question: str, paragraphs_text: str = "", + key_context: str = "", ) -> List[Dict[str, str]]: """Build chat messages for latent mode agents. In latent mode, prior context is carried via KV-cache. + key_context: optional key tokens extracted by attention for hybrid mode. """ if role == "finder": user_prompt = ( @@ -80,15 +82,28 @@ def build_latent_prompt( elif role == "answerer": # In latent mode, the Answerer gets the question but NOT the paragraphs. # The Finder's KV-cache carries the model's understanding of the context. - user_prompt = ( - f"You are an Answerer Agent. A Finder agent has already analyzed " - f"the relevant paragraphs for you. Use their analysis to answer " - f"the question.\n\n" - f"## Question: {question}\n\n" - f"Based on the analysis provided, give a short, precise answer " - f"(a few words at most).\n\n" - f"Answer:" - ) + if key_context: + # Hybrid mode: include key tokens as context + user_prompt = ( + f"You are an Answerer Agent. A Finder agent has already analyzed " + f"the relevant paragraphs for you. Use their analysis and the " + f"key context below to answer the question.\n\n" + f"## Key Context: {key_context}\n\n" + f"## Question: {question}\n\n" + f"Based on the analysis and context provided, give a short, precise answer " + f"(a few words at most).\n\n" + f"Answer:" + ) + else: + user_prompt = ( + f"You are an Answerer Agent. A Finder agent has already analyzed " + f"the relevant paragraphs for you. Use their analysis to answer " + f"the question.\n\n" + f"## Question: {question}\n\n" + f"Based on the analysis provided, give a short, precise answer " + f"(a few words at most).\n\n" + f"Answer:" + ) elif role == "decomposer": user_prompt = ( f"You are a Decomposer Agent. Break the following multi-hop question " diff --git a/benchmarks/hotpotqa/pipeline_rosetta.py b/benchmarks/hotpotqa/pipeline_rosetta.py index 526210c..8ea98da 100644 --- a/benchmarks/hotpotqa/pipeline_rosetta.py +++ b/benchmarks/hotpotqa/pipeline_rosetta.py @@ -201,17 +201,6 @@ def run_rosetta_pipeline( inject_t0 = time.perf_counter() embed_input = rosetta_embeds.to(device).to(model_b.dtype) - # Hybrid: prepend key text tokens as embeddings before latent - hybrid_text_tokens = 0 - if key_text: - key_ids_b = tokenizer_b.encode(key_text, add_special_tokens=False) - if key_ids_b: - key_ids_tensor = torch.tensor([key_ids_b], device=device) - key_embeds = model_b.get_input_embeddings()(key_ids_tensor) # [1, K, D_tgt] - key_embeds = key_embeds.to(model_b.dtype) - embed_input = torch.cat([key_embeds, embed_input], dim=1) # [1, K+1, D_tgt] - hybrid_text_tokens = len(key_ids_b) - embed_mask = torch.ones( (1, embed_input.shape[1]), dtype=torch.long, device=device, ) @@ -225,7 +214,15 @@ def run_rosetta_pipeline( past_kv_b = prime_out.past_key_values injection_ms = (time.perf_counter() - inject_t0) * 1000 - messages = build_latent_prompt(answerer.role, question) + # Hybrid: inject key text as context in Agent B's prompt (via input_ids) + hybrid_text_tokens = 0 + if key_text: + hybrid_text_tokens = len(tokenizer_b.encode(key_text, add_special_tokens=False)) + messages = build_latent_prompt( + answerer.role, question, key_context=key_text, + ) + else: + messages = build_latent_prompt(answerer.role, question) prompt_text = render_prompt(tokenizer_b, messages) input_ids, attention_mask = tokenize_prompt(tokenizer_b, prompt_text, device) From a1882f622ba829619fbbddc7466ed22c0a6e4719 Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 20:19:14 +0000 Subject: [PATCH 12/14] Fix hybrid: switch SDPA to eager attention for key token extraction SDPA silently ignores output_attentions=True, so attention weights were never returned. key_text was always None, meaning hybrid mode was effectively running pure rosetta. Temporarily switch to eager attention for the dummy forward pass when hybrid_k > 0. Co-Authored-By: Claude Opus 4.6 --- benchmarks/hotpotqa/pipeline_rosetta.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/benchmarks/hotpotqa/pipeline_rosetta.py b/benchmarks/hotpotqa/pipeline_rosetta.py index 8ea98da..8d75388 100644 --- a/benchmarks/hotpotqa/pipeline_rosetta.py +++ b/benchmarks/hotpotqa/pipeline_rosetta.py @@ -135,6 +135,18 @@ def run_rosetta_pipeline( dummy_mask = torch.ones((1, past_len + 1), dtype=torch.long, device=device) eos_id = tokenizer_a.eos_token_id or 0 dummy_ids = torch.tensor([[eos_id]], device=device) + + # SDPA doesn't support output_attentions — switch to eager temporarily + need_attentions = hybrid_k > 0 + original_attn_impl = None + if need_attentions and hasattr(model_a, 'config'): + original_attn_impl = getattr(model_a.config, '_attn_implementation', None) + model_a.config._attn_implementation = 'eager' + # Also update all attention layers + for module in model_a.modules(): + if hasattr(module, '_attn_implementation'): + module._attn_implementation = 'eager' + with torch.no_grad(): out = model_a( input_ids=dummy_ids, @@ -144,6 +156,14 @@ def run_rosetta_pipeline( output_attentions=True, return_dict=True, ) + + # Restore original attention implementation + if original_attn_impl is not None: + model_a.config._attn_implementation = original_attn_impl + for module in model_a.modules(): + if hasattr(module, '_attn_implementation'): + module._attn_implementation = original_attn_impl + last_hidden = out.hidden_states[-1][:, -1, :] # [1, D_src] # Compute attention entropy from last layer @@ -162,6 +182,8 @@ def run_rosetta_pipeline( ) if verbose: print(f" [Hybrid] Extracted {hybrid_k} key tokens: {key_text[:100]}...") + elif hybrid_k > 0: + print(f" WARNING: output_attentions returned empty — hybrid extraction skipped") # Project to target model space projected, proj_metrics = conn_a.project_hidden_for_cross_model( From aec31ecefe1af308619d93ecea09fb194337082a Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 20:25:57 +0000 Subject: [PATCH 13/14] Fix hybrid: load model with eager attention at from_pretrained time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SDPA attention silently ignores output_attentions=True, so all hybrid experiment runs had no attention weights — key_text was always None. The previous runtime _attn_implementation override didn't work because HuggingFace selects the attention module class at from_pretrained time. Fix: pass attn_implementation="eager" to from_pretrained when hybrid_k > 0. Remove the broken runtime override from pipeline_rosetta.py. Co-Authored-By: Claude Opus 4.6 --- benchmarks/hotpotqa/pipeline_rosetta.py | 18 ------------------ benchmarks/hotpotqa/run_hotpotqa.py | 4 +++- benchmarks/shared/model_utils.py | 15 ++++++++++++--- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/benchmarks/hotpotqa/pipeline_rosetta.py b/benchmarks/hotpotqa/pipeline_rosetta.py index 8d75388..16fe205 100644 --- a/benchmarks/hotpotqa/pipeline_rosetta.py +++ b/benchmarks/hotpotqa/pipeline_rosetta.py @@ -136,17 +136,6 @@ def run_rosetta_pipeline( eos_id = tokenizer_a.eos_token_id or 0 dummy_ids = torch.tensor([[eos_id]], device=device) - # SDPA doesn't support output_attentions — switch to eager temporarily - need_attentions = hybrid_k > 0 - original_attn_impl = None - if need_attentions and hasattr(model_a, 'config'): - original_attn_impl = getattr(model_a.config, '_attn_implementation', None) - model_a.config._attn_implementation = 'eager' - # Also update all attention layers - for module in model_a.modules(): - if hasattr(module, '_attn_implementation'): - module._attn_implementation = 'eager' - with torch.no_grad(): out = model_a( input_ids=dummy_ids, @@ -157,13 +146,6 @@ def run_rosetta_pipeline( return_dict=True, ) - # Restore original attention implementation - if original_attn_impl is not None: - model_a.config._attn_implementation = original_attn_impl - for module in model_a.modules(): - if hasattr(module, '_attn_implementation'): - module._attn_implementation = original_attn_impl - last_hidden = out.hidden_states[-1][:, -1, :] # [1, D_src] # Compute attention entropy from last layer diff --git a/benchmarks/hotpotqa/run_hotpotqa.py b/benchmarks/hotpotqa/run_hotpotqa.py index df51cd4..b5626ba 100644 --- a/benchmarks/hotpotqa/run_hotpotqa.py +++ b/benchmarks/hotpotqa/run_hotpotqa.py @@ -298,7 +298,9 @@ def run_benchmark(config: dict) -> dict: print() dataset = load_dataset(max_samples, max_context_tokens) - model, tokenizer, connector, identity = load_model(model_name, device) + # Hybrid mode needs eager attention for output_attentions (SDPA silently ignores it) + attn_impl = "eager" if hybrid_k > 0 else None + model, tokenizer, connector, identity = load_model(model_name, device, attn_implementation=attn_impl) direct_results = None latent_results = None diff --git a/benchmarks/shared/model_utils.py b/benchmarks/shared/model_utils.py index dfbc78e..7dededb 100644 --- a/benchmarks/shared/model_utils.py +++ b/benchmarks/shared/model_utils.py @@ -30,8 +30,13 @@ def auto_device(device: Optional[str]) -> str: return "cpu" -def load_model(model_name: str, device: str): - """Load model and tokenizer, return (model, tokenizer, connector, identity).""" +def load_model(model_name: str, device: str, attn_implementation: Optional[str] = None): + """Load model and tokenizer, return (model, tokenizer, connector, identity). + + Args: + attn_implementation: Override attention implementation (e.g. "eager" for + output_attentions support — SDPA silently ignores it). + """ from transformers import AutoModelForCausalLM, AutoTokenizer from avp.connectors.huggingface import HuggingFaceConnector @@ -45,7 +50,11 @@ def load_model(model_name: str, device: str): dtype = torch.float16 else: dtype = torch.float32 - model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype) + kwargs = {} + if attn_implementation is not None: + kwargs["attn_implementation"] = attn_implementation + print(f" Using attention implementation: {attn_implementation}") + model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype, **kwargs) tokenizer = AutoTokenizer.from_pretrained(model_name) model.to(device) model.eval() From 4c3607b07e8dea3c812c33a58f93e1228951123e Mon Sep 17 00:00:00 2001 From: Stanislav Date: Sat, 14 Mar 2026 20:30:02 +0000 Subject: [PATCH 14/14] Fix hybrid extraction: skip template tokens to avoid attention sinks Attention-based extraction was picking up system prompt and instruction tokens instead of paragraph content. Added find_content_start() to locate the "## Paragraphs:" marker and zero out template token scores. Co-Authored-By: Claude Opus 4.6 --- benchmarks/hotpotqa/pipeline_rosetta.py | 34 +++++++++++++++++++++---- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/benchmarks/hotpotqa/pipeline_rosetta.py b/benchmarks/hotpotqa/pipeline_rosetta.py index 16fe205..1514cbe 100644 --- a/benchmarks/hotpotqa/pipeline_rosetta.py +++ b/benchmarks/hotpotqa/pipeline_rosetta.py @@ -18,6 +18,26 @@ from .evaluate import exact_match, extract_answer, token_f1 +def find_content_start(input_ids: Any, tokenizer: Any) -> int: + """Find the token position where paragraph content begins. + + Searches for "## Paragraphs:" or "Paragraphs:" marker in the tokenized prompt. + Falls back to skipping the first 20% of tokens (covers system + instruction). + """ + # Decode to find the marker position in text, then map back to tokens + full_text = tokenizer.decode(input_ids[0], skip_special_tokens=False) + for marker in ["## Paragraphs:", "Paragraphs:", "## paragraphs:"]: + marker_pos = full_text.find(marker) + if marker_pos >= 0: + # Count tokens up to the marker (+ marker itself) + prefix = full_text[:marker_pos + len(marker)] + prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) + return len(prefix_tokens) + + # Fallback: skip first 20% of tokens (system + instruction) + return int(input_ids.shape[-1] * 0.2) + + def extract_key_tokens( attention_weights: Any, input_ids: Any, @@ -28,7 +48,8 @@ def extract_key_tokens( """Extract top-K important tokens from attention weights. Uses attention from the final token (over the full KV-cache) to score - input token importance. Returns decoded text of the top-K tokens. + input token importance. Skips system prompt and instruction tokens to + avoid attention sinks. Returns decoded text of the top-K tokens. Args: attention_weights: Last layer attention, shape [batch, heads, 1, seq_len] @@ -40,11 +61,14 @@ def extract_key_tokens( # Average across heads: [seq_len] attn_scores = attention_weights[0, :, -1, :].mean(dim=0) - # Only score the original prompt tokens (not latent step positions) - prompt_scores = attn_scores[:prompt_tokens] + # Only score content tokens (skip system prompt + instruction to avoid attention sinks) + content_start = find_content_start(input_ids, tokenizer) + prompt_scores = attn_scores[:prompt_tokens].clone() + prompt_scores[:content_start] = -float("inf") # Zero out template tokens - # Select top-K by attention (capped at available tokens) - k = min(k, prompt_tokens) + # Select top-K by attention (capped at available content tokens) + content_tokens = prompt_tokens - content_start + k = min(k, max(content_tokens, 1)) topk_indices = prompt_scores.topk(k).indices # Sort by position to maintain reading order topk_indices = topk_indices.sort().values