Gym Cleanup & Continuous Learning GRPO Architecture Plan¶
Date: October 28, 2025
Status: Code review complete, cleanup required before new features
Executive Summary¶
Code Review Verdict: ⚠️ Needs Improvement (Medium-High Risk)
The gym codebase has solid foundations but suffers from technical debt:
- ❌ 400+ lines of duplicated code across 8 trainers
- ❌ Class name conflicts (trainer.py vs trainer_simple.py)
- ❌ 300+ line parameter sprawl in a single dataclass
- ❌ Unclear file organization (simple/v1/v2/unified variants)
Recommendation: Refactor HIGH PRIORITY items before adding Continuous Learning GRPO (5-7 days investment → long-term maintainability)
Critical Issues Identified¶
🔴 Issue #1: Duplicate Class Names¶
Problem:
src/gym/train/grpo/
├── trainer.py → class GRPOTrainer
└── trainer_simple.py → class GRPOTrainer (SAME NAME!)
Impact:
- Import ambiguity
- Maintenance nightmare
- Confusing for developers
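For illustration, the ambiguity shows up at any import site (module paths as in the tree above): whichever import runs second silently shadows the first unless one class is aliased.
# Both modules export a class named GRPOTrainer, so callers must alias one of them:
from gym.train.grpo.trainer import GRPOTrainer                               # full-featured
from gym.train.grpo.trainer_simple import GRPOTrainer as SimpleGRPOTrainer   # same name, aliased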
Solution:
# Option A: Rename (quick fix)
class GRPOTrainer: ...        # Full-featured
class GRPOTrainerSimple: ...  # Minimal

# Option B: Remove simple versions (recommended)
# Move them to examples/simple_trainers/ as demo code
🔴 Issue #2: Massive Parameter Sprawl¶
Problem: Single 300+ line dataclass contains ALL RL method parameters
@dataclass
class RLHFArguments:
    # PPO params (7)
    ppo_buffer_size: int = 1
    ppo_epochs: int = 4
    # DPO params (9)
    pref_beta: float = 0.1
    dpo_label_smoothing: float = 0.0
    # KTO params (3)
    kto_chosen_weight: float = 1.0
    # GRPO params (11)
    grpo_group_size: int = 8
    grpo_beta: float = 0.1
    continuous_learning_grpo: bool = False
    experience_lib_path: Optional[str] = None
    # ... 7 more
    # GSPO params (6)
    gspo_group_size: int = 8
    # Shared (9)
    ref_model: Optional[str] = None
    # ...
Impact:
- Hard to understand which params apply to which method
- Easy to set conflicting parameters
- No type checking
- Grows linearly with each new method
Solution: Hierarchical Dataclasses
@dataclass
class BaseRLHFArguments:
    """Shared parameters for all RLHF methods."""
    ref_model: Optional[str] = None
    ref_model_adapters: Optional[str] = None
    reward_model: Optional[str] = None
    beta: float = 0.1  # Common KL penalty


@dataclass
class GRPOArguments(BaseRLHFArguments):
    """Standard GRPO parameters."""
    group_size: int = 8
    clip_range: float = 0.2
    normalize_advantages: bool = True


@dataclass
class ContinuousLearningGRPOArguments(GRPOArguments):
    """Continuous Learning GRPO additional parameters."""
    training_free: bool = True
    experience_lib_path: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_base_url: str = "https://api.deepseek.com/v1"
    llm_model: str = "deepseek-chat"
    semantic_max_operations: int = 3
    rollout_temperature: float = 0.7
    use_groundtruth: bool = True


@dataclass
class DPOArguments(BaseRLHFArguments):
    """DPO-specific parameters."""
    ftx: float = 0.0
    label_smoothing: float = 0.0
    loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid"
Benefits:
- ✅ Type-safe parameter checking
- ✅ Clear inheritance hierarchy
- ✅ IDE autocomplete works
- ✅ Only relevant params visible per method
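As a quick illustration of the type-safety point, a minimal usage sketch assuming the dataclasses above (values are arbitrary):

# Each stage constructs only its own argument class; unknown fields fail immediately.
args = ContinuousLearningGRPOArguments(
    group_size=5,
    experience_lib_path="./output/experiences.json",
)

# A stray PPO parameter is rejected at construction time instead of being silently
# accepted by a catch-all RLHFArguments:
# ContinuousLearningGRPOArguments(ppo_epochs=4)  -> TypeError: unexpected keyword argument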
🔴 Issue #3: Code Duplication Across All Trainers¶
400+ lines duplicated across 8 trainers:
Pattern 1: Reference Model Setup (repeated 4x)
# dpo/trainer.py, kto/trainer.py, grpo/trainer.py, gspo/trainer.py
if ref_model is not None:
    if self.is_deepspeed_enabled:
        if not (getattr(ref_model, "is_loaded_in_8bit", False) or
                getattr(ref_model, "is_loaded_in_4bit", False)):
            self.ref_model = self._prepare_deepspeed(self.ref_model)
    else:
        self.ref_model = self.accelerator.prepare_model(
            self.ref_model, evaluation_mode=True
        )
Pattern 2: Optimizer/Scheduler Creation (repeated 8x)
# ALL trainers
@override
def create_optimizer(self):
    if self.optimizer is None:
        self.optimizer = create_custom_optimizer(...)
    return super().create_optimizer()

@override
def create_scheduler(self, num_training_steps, optimizer=None):
    create_custom_scheduler(...)
    return super().create_scheduler(...)
Pattern 3: Dropout Disabling (repeated 4x)
if disable_dropout:
    disable_dropout_in_model(model)
    if ref_model is not None:
        disable_dropout_in_model(ref_model)
Solution: Create BaseRLHFTrainer
# src/gym/train/base_rlhf_trainer.py
from abc import ABC, abstractmethod
from collections import defaultdict

from transformers import Trainer
from typing_extensions import override


class BaseRLHFTrainer(Trainer, ABC):
    """Base class for all RLHF trainers (PPO, DPO, KTO, GRPO, GSPO).

    Handles common functionality:
    - Reference model setup
    - Optimizer/scheduler creation
    - Dropout management
    - Processor callbacks
    - Metric storage
    """

    def __init__(
        self,
        model,
        ref_model,
        args,
        finetuning_args,
        processor=None,
        disable_dropout=True,
        **kwargs,
    ):
        self.finetuning_args = finetuning_args
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # Disable dropout once (not per trainer)
        if disable_dropout:
            self._disable_dropout(model, ref_model)

        super().__init__(model=model, args=args, **kwargs)

        # Setup reference model once
        if ref_model is not None:
            self._setup_ref_model()

        # Add processor callback once
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

    def _setup_ref_model(self):
        """Setup reference model for DeepSpeed or standard training."""
        if self.is_deepspeed_enabled:
            if not (getattr(self.ref_model, "is_loaded_in_8bit", False) or
                    getattr(self.ref_model, "is_loaded_in_4bit", False)):
                self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare_model(
                self.ref_model, evaluation_mode=True
            )
        self.ref_model.eval()

    def _disable_dropout(self, model, ref_model):
        """Disable dropout in model and reference model."""
        from trl.trainer import disable_dropout_in_model

        disable_dropout_in_model(model)
        if ref_model is not None:
            disable_dropout_in_model(ref_model)

    @override
    def create_optimizer(self):
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(
                self.model, self.args, self.finetuning_args
            )
        return super().create_optimizer()

    @override
    def create_scheduler(self, num_training_steps, optimizer=None):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @abstractmethod
    def compute_loss(self, model, inputs, return_outputs=False):
        """Each trainer implements its own loss computation."""
        pass
Then each trainer becomes simple:
class GRPOTrainer(BaseRLHFTrainer):
    """Group Relative Policy Optimization Trainer."""

    def __init__(
        self,
        finetuning_args: GRPOArguments,
        ...
    ):
        # GRPO-specific parameters only
        self.group_size = finetuning_args.group_size
        self.clip_range = finetuning_args.clip_range
        self.normalize_advantages = finetuning_args.normalize_advantages

        # Parent handles: ref_model, dropout, optimizer, scheduler
        super().__init__(
            model=model,
            ref_model=ref_model,
            args=args,
            finetuning_args=finetuning_args,
            ...
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """GRPO-specific loss computation."""
        # Only GRPO logic here
        ...
Lines Saved: 400+ lines across 8 trainers
Architectural Decision: Continuous Learning GRPO¶
Option A: Separate Class (RECOMMENDED)¶
class GRPOTrainer(BaseRLHFTrainer):
    """Standard GRPO with parameter updates."""

    def __init__(self, finetuning_args: GRPOArguments, ...):
        self.group_size = finetuning_args.group_size
        super().__init__(...)

    def compute_loss(self, model, inputs, return_outputs=False):
        # 1. Generate rollouts
        # 2. Compute numerical advantages
        # 3. Compute policy gradient loss
        # 4. Return loss for optimizer.step()
        return loss


class ContinuousLearningGRPOTrainer(BaseRLHFTrainer):
    """Continuous Learning GRPO with experience library updates."""

    def __init__(self, finetuning_args: ContinuousLearningGRPOArguments, ...):
        self.group_size = finetuning_args.group_size

        # Training-free specific setup
        self.experience_manager = ExperienceManager(
            checkpoint_path=finetuning_args.experience_lib_path
        )
        llm_client = LLMClient(
            api_key=finetuning_args.llm_api_key,
            base_url=finetuning_args.llm_base_url,
            model=finetuning_args.llm_model,
        )
        self.semantic_extractor = SemanticExtractor(
            llm_client=llm_client,
            max_operations=finetuning_args.semantic_max_operations,
        )

        super().__init__(...)

        # Freeze model
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def compute_loss(self, model, inputs, return_outputs=False):
        # 1. Inject experiences into prompts
        # 2. Generate rollouts with experiences
        # 3. Extract semantic advantages
        # 4. Update experience library
        # 5. Return zero loss (no gradient update!)
        return torch.tensor(0.0, device=model.device)
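As a rough illustration of step 1 above, experience injection could be a simple prompt-prefixing helper. The retrieve() call below is an assumption about the ExperienceManager API, not a settled interface:

# Hypothetical helper inside ContinuousLearningGRPOTrainer (ExperienceManager API assumed).
def _inject_experiences(self, prompt: str, top_k: int = 4) -> str:
    experiences = self.experience_manager.retrieve(prompt, top_k=top_k)  # assumed method
    if not experiences:
        return prompt
    guidance = "\n".join(f"- {exp}" for exp in experiences)
    return f"Relevant experience from earlier rollouts:\n{guidance}\n\n{prompt}"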
Why Option A (Separate Class) is BETTER:
✅ Clear Separation of Concerns
  - Standard GRPO: numerical advantages → gradient updates
  - Continuous Learning GRPO: semantic advantages → experience updates
  - Two fundamentally different learning paradigms

✅ Type Safety
  - GRPOTrainer requires GRPOArguments
  - ContinuousLearningGRPOTrainer requires ContinuousLearningGRPOArguments
  - Static type checks catch configuration errors

✅ Easier to Understand
  - New developers see two distinct classes
  - No confusing if training_free: branches
  - Clear which class to use

✅ Better Testing
  - Test standard GRPO separately
  - Test Continuous Learning GRPO separately
  - No interaction between the two

✅ Follows the Single Responsibility Principle
  - Each class has one reason to change
  - Standard GRPO changes don't affect training-free
  - Training-free changes don't affect standard

✅ Easier to Extend
  - Want Continuous Learning GSPO? Create ContinuousLearningGSPOTrainer
  - Want training-free variants of other methods? Same pattern
  - Scales better than flags (see the dispatch sketch below)

✅ Less Cognitive Load
  - No mental overhead tracking which mode is active
  - Method signatures are simpler
  - Fewer conditional branches
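One way the separate-class approach stays flag-free end to end is a small stage-to-trainer registry in the workflow layer. This is a sketch under the proposed names, not existing code:

# Hypothetical registry: the stage string picks both the trainer class and its
# argument dataclass, so neither trainer needs a mode flag.
TRAINER_REGISTRY = {
    "grpo": (GRPOTrainer, GRPOArguments),
    "grpo_free": (ContinuousLearningGRPOTrainer, ContinuousLearningGRPOArguments),
}

def build_trainer(stage: str, raw_args: dict, **trainer_kwargs):
    trainer_cls, args_cls = TRAINER_REGISTRY[stage]
    finetuning_args = args_cls(**raw_args)
    return trainer_cls(finetuning_args=finetuning_args, **trainer_kwargs)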
Option B: Flag-Based (NOT RECOMMENDED)¶
class GRPOTrainer(BaseRLHFTrainer):
    def __init__(self, finetuning_args, ...):
        self.training_free = finetuning_args.continuous_learning_grpo
        if self.training_free:
            # Setup experience manager, semantic extractor
            # Freeze model
            ...
        else:
            # Standard setup
            ...

    def compute_loss(self, model, inputs, return_outputs=False):
        if self.training_free:
            return self._compute_loss_training_free(...)
        else:
            return self._compute_loss_standard(...)
Why Option B is WORSE:
❌ Poor Separation of Concerns
  - Single class doing two very different things
  - Violates the Single Responsibility Principle

❌ Type Safety Issues
  - Can't enforce ContinuousLearningGRPOArguments statically
  - Runtime errors from missing parameters

❌ Testing Complexity
  - Need to test both modes in a single class
  - State can leak between modes
  - More edge cases to cover

❌ Cognitive Overhead
  - Developers must track which mode is active
  - More if branches = more mental load
  - Harder to reason about behavior

❌ Harder to Extend
  - Adding more variants compounds complexity
  - Class grows linearly with each new variant
Directory Structure (Proposed)¶
src/gym/train/
├── base_rlhf_trainer.py # NEW: Base class for all RLHF
├── sft/
│ └── trainer.py
├── ppo/
│ └── trainer.py
├── dpo/
│ └── trainer.py
├── kto/
│ └── trainer.py
├── grpo/
│ ├── __init__.py
│ ├── trainer.py # Standard GRPO (parameter updates)
│ ├── workflow.py
│ └── training_free/ # NEW: Continuous Learning GRPO
│ ├── __init__.py
│ ├── trainer.py # ContinuousLearningGRPOTrainer
│ ├── workflow.py
│ ├── experience_manager.py
│ ├── semantic_extractor.py
│ ├── api_model_adapter.py
│ └── domains/
│ ├── __init__.py
│ ├── base.py
│ ├── math/
│ │ ├── dataset.py
│ │ ├── verify.py
│ │ └── prompts.py
│ └── web/
│ ├── dataset.py
│ ├── verify.py
│ └── prompts.py
└── gspo/
├── __init__.py
└── trainer.py # Standard GSPO
Naming Convention:
- GRPOTrainer - Standard GRPO with parameter updates
- ContinuousLearningGRPOTrainer - Continuous Learning GRPO with experience updates
- GSPOTrainer - Standard GSPO
- ContinuousLearningGSPOTrainer - Future extension
Import Convention:
# Standard GRPO
from gym.train.grpo import GRPOTrainer
# Continuous Learning GRPO
from gym.train.grpo.training_free import ContinuousLearningGRPOTrainer
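A small sketch of the package exports that would back this convention (package layout as in the proposed tree; nothing here exists yet):

# src/gym/train/grpo/__init__.py
from .trainer import GRPOTrainer

# src/gym/train/grpo/training_free/__init__.py
from .trainer import ContinuousLearningGRPOTrainer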
Parameter Configuration (Proposed)¶
# src/gym/hparams/finetuning_args.py
@dataclass
class BaseRLHFArguments:
    """Shared parameters for all RLHF methods."""
    ref_model: Optional[str] = None
    ref_model_adapters: Optional[str] = None
    ref_model_quantization_bit: Optional[int] = None
    reward_model: Optional[str] = None
    reward_model_adapters: Optional[str] = None
    reward_model_quantization_bit: Optional[int] = None
    reward_model_type: Literal["lora", "full", "api"] = "lora"


@dataclass
class GRPOArguments(BaseRLHFArguments):
    """Standard GRPO parameters."""
    group_size: int = field(
        default=8,
        metadata={"help": "Group size for relative advantage computation."},
    )
    beta: float = field(
        default=0.1,
        metadata={"help": "Beta coefficient for KL penalty."},
    )
    clip_range: float = field(
        default=0.2,
        metadata={"help": "Clipping range for importance ratios."},
    )
    normalize_advantages: bool = field(
        default=True,
        metadata={"help": "Whether to normalize advantages within groups."},
    )


@dataclass
class ContinuousLearningGRPOArguments(GRPOArguments):
    """Continuous Learning GRPO parameters."""
    # Inherits group_size, beta, clip_range, normalize_advantages
    experience_lib_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save/load experience library."},
    )
    llm_api_key: Optional[str] = field(
        default=None,
        metadata={"help": "API key for LLM semantic extraction (DeepSeek/OpenAI)."},
    )
    llm_base_url: str = field(
        default="https://api.deepseek.com/v1",
        metadata={"help": "Base URL for LLM API."},
    )
    llm_model: str = field(
        default="deepseek-chat",
        metadata={"help": "LLM model name for semantic extraction."},
    )
    semantic_max_operations: int = field(
        default=3,
        metadata={"help": "Max operations per group critique."},
    )
    rollout_temperature: float = field(
        default=0.7,
        metadata={"help": "Temperature for rollout generation."},
    )
    use_groundtruth: bool = field(
        default=True,
        metadata={"help": "Use ground truth in semantic extraction."},
    )


@dataclass
class FinetuningArguments:
    """Top-level finetuning arguments."""
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "grpo", "gspo", "grpo_free"] = field(
        default="sft",
        metadata={"help": "Training stage to perform."},
    )
    # Method-specific arguments will be loaded based on stage
    # No more massive RLHFArguments with all methods mixed
Configuration Usage:
# configs/grpo_standard.yaml
stage: grpo
finetuning_type: lora
grpo:
  group_size: 8
  beta: 0.1
  clip_range: 0.2
  normalize_advantages: true

# configs/grpo_training_free.yaml
stage: grpo_free
finetuning_type: lora
grpo:
  group_size: 5
  beta: 0.1
training_free:
  experience_lib_path: ./output/experiences.json
  llm_api_key: ${DEEPSEEK_API_KEY}
  llm_base_url: https://api.deepseek.com/v1
  llm_model: deepseek-chat
  semantic_max_operations: 3
  rollout_temperature: 0.7
  use_groundtruth: true
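A minimal loading sketch for the training-free config above, assuming PyYAML and the dataclasses proposed earlier. The ${DEEPSEEK_API_KEY} placeholder is not native YAML, so the loader expands environment variables itself (the helper name is hypothetical):

import os
import yaml

def load_training_free_args(path: str) -> "ContinuousLearningGRPOArguments":
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # Expand ${VAR}-style placeholders such as ${DEEPSEEK_API_KEY} in string values.
    training_free = {
        key: os.path.expandvars(value) if isinstance(value, str) else value
        for key, value in cfg.get("training_free", {}).items()
    }
    # Merge the shared grpo section with the training-free section.
    return ContinuousLearningGRPOArguments(**cfg.get("grpo", {}), **training_free)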
Implementation Roadmap¶
Phase 1: Cleanup (Week 1)¶
Goal: Clean foundation for Continuous Learning GRPO
Day 1-2: Create BaseRLHFTrainer¶
- Create src/gym/train/base_rlhf_trainer.py
- Extract common code (ref_model, optimizer, scheduler)
- Add comprehensive tests
- Update all 8 trainers to inherit from it
Day 3: Refactor Parameter Classes¶
- Create hierarchical dataclasses
- Split RLHFArguments into method-specific classes
- Update all trainers to use new structure
- Test configuration loading
Day 4-5: Resolve File Conflicts¶
- Rename or remove trainer_simple.py files
- Clean up quantization module organization
- Update imports and documentation
Deliverable: Clean, DRY codebase ready for new features
Phase 2: Continuous Learning GRPO Implementation (Week 2-3)¶
Goal: Implement Continuous Learning GRPO with clean architecture
Day 1-2: Create ContinuousLearningGRPOTrainer¶
- Create src/gym/train/grpo/training_free/trainer.py
- Implement compute_loss() for training-free mode
- Add experience injection logic
- Implement rollout generation
Day 3-4: Domain Modules¶
- Create domains/math/ module
  - dataset.py (AIME loader)
  - verify.py (correctness checker)
  - prompts.py (math templates)
- Create domains/web/ module
Day 5: Testing¶
- Unit tests for ContinuousLearningGRPOTrainer
- Integration tests with toy dataset
- Test experience library persistence
Deliverable: Working Continuous Learning GRPO trainer
Phase 3: End-to-End Testing (Week 4)¶
Goal: Validate full workflow
Day 1-2: Example Scripts¶
- Create scripts/train_grpo_free_math.py
- Create YAML configs
- Test with API model (DeepSeek)
Day 3-4: Benchmarking¶
- Run 3-epoch training on AIME24 (100 samples)
- Validate experience library growth
- Compare with Tencent baseline
- Cost analysis
Day 5: Documentation¶
- Update README with Continuous Learning GRPO
- Add usage examples
- Create troubleshooting guide
Deliverable: Production-ready Continuous Learning GRPO
Testing Strategy¶
Unit Tests¶
# tests/train/test_base_rlhf_trainer.py
def test_ref_model_setup():
    """Test reference model setup works for all trainers."""
    ...

def test_optimizer_creation():
    """Test custom optimizer creation."""
    ...


# tests/train/test_continuous_learning_grpo_trainer.py
def test_experience_injection():
    """Test experience library injection into prompts."""
    ...

def test_no_parameter_updates():
    """Verify model weights unchanged."""
    ...

def test_semantic_advantage_extraction():
    """Test LLM-based advantage extraction."""
    ...
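A minimal sketch of how the frozen-weights test could be written, assuming a fixture (here called tiny_trainer, a hypothetical name) that builds a ContinuousLearningGRPOTrainer around a small model and a toy dataset:

import torch

def test_no_parameter_updates(tiny_trainer):  # fixture name is hypothetical
    """Snapshot weights, run a short training loop, and assert nothing changed."""
    before = {name: p.detach().clone() for name, p in tiny_trainer.model.named_parameters()}
    tiny_trainer.train()
    for name, param in tiny_trainer.model.named_parameters():
        assert torch.equal(before[name], param.detach()), f"{name} changed in training-free mode"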
Integration Tests¶
# tests/train/test_grpo_training_free_integration.py
import json
from pathlib import Path

def test_full_training_workflow():
    """Run 2 epochs on a toy dataset."""
    trainer = ContinuousLearningGRPOTrainer(...)
    trainer.train()

    # Verify results
    assert Path("output/experiences.json").exists()
    experiences = json.load(open("output/experiences.json"))
    assert len(experiences) > 0
Success Criteria¶
Functional¶
- BaseRLHFTrainer eliminates 400+ lines of duplication
- All 8 trainers inherit from BaseRLHFTrainer
- ContinuousLearningGRPOTrainer runs 3-epoch training
- Experience library grows (50-200 experiences)
- Zero parameter updates (model frozen)
- Checkpoint save/load works
Quality¶
- All tests passing (100% coverage for new code)
- No code duplication
- Clear separation of concerns
- Type-safe parameter handling
- Comprehensive documentation
Performance¶
- AIME24 accuracy ≥75% (target: 82.7%)
- Training cost ≤$25 (target: $18)
- Training time ≤8 hours (target: 6 hours)
Conclusion¶
Recommended Approach:
- ✅ Refactor First (Week 1)
  - Create BaseRLHFTrainer
  - Refactor parameter classes
  - Resolve file conflicts
- ✅ Implement Continuous Learning GRPO (Week 2-3)
  - Create ContinuousLearningGRPOTrainer (separate class)
  - Implement domain modules
  - Comprehensive testing
- ✅ Validate & Document (Week 4)
  - End-to-end testing
  - Benchmarking
  - Documentation
Why This Approach:
- ✅ Clean foundation before new features
- ✅ Eliminate 400+ lines of duplication
- ✅ Clear separation: GRPOTrainer vs ContinuousLearningGRPOTrainer
- ✅ Type-safe, maintainable, extensible
- ✅ 1 week cleanup investment → long-term maintainability
Estimated Timeline: 4 weeks total (1 week cleanup + 3 weeks implementation)
Next Step: Approve cleanup plan and proceed with Phase 1 (BaseRLHFTrainer creation)?