Gym Cleanup & Continuous Learning GRPO Architecture Plan

Date: October 28, 2025
Status: Code review complete, cleanup required before new features

Executive Summary

Code Review Verdict: ⚠️ Needs Improvement (Medium-High Risk)

The gym codebase has solid foundations but suffers from technical debt:
  • ❌ 400+ lines of duplicated code across 8 trainers
  • ❌ Class name conflicts (trainer.py vs trainer_simple.py)
  • ❌ Parameter sprawl: a single 300+ line dataclass holds every method's parameters
  • ❌ Unclear file organization (simple/v1/v2/unified variants)

Recommendation: Refactor HIGH PRIORITY items before adding Continuous Learning GRPO (5-7 days investment → long-term maintainability)


Critical Issues Identified

🔴 Issue #1: Duplicate Class Names

Problem:

src/gym/train/grpo/
├── trainer.py          → class GRPOTrainer
└── trainer_simple.py   → class GRPOTrainer (SAME NAME!)

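Impact:
  • Import ambiguity: which GRPOTrainer a module gets depends on which file it imports from
  • Maintenance burden: fixes applied to one class are easily missed in the other
  • Confusing for developers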

Solution:

# Option A: Rename (quick fix)
class GRPOTrainer: ...        # trainer.py: full-featured
class GRPOTrainerSimple: ...  # trainer_simple.py: minimal

# Option B: Remove simple versions (recommended)
# Move them to examples/simple_trainers/ as demo code
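Whichever option is chosen, a small guard test could stop the conflict from coming back. It would scan the training package for top-level class names defined in more than one module. This is a sketch, not existing gym tooling. `find_duplicate_classes` and `collect_sources` are hypothetical helpers built on the standard `ast` module, and they only inspect module-level `class` statements.

```python
import ast
from collections import defaultdict
from pathlib import Path


def find_duplicate_classes(sources: dict[str, str]) -> dict[str, list[str]]:
    """Return {class_name: [module paths]} for top-level classes defined in 2+ modules.

    `sources` maps a module path to its source text, so the function can be
    tested without touching the filesystem.
    """
    definitions: dict[str, list[str]] = defaultdict(list)
    for path, source in sorted(sources.items()):
        tree = ast.parse(source, filename=path)
        for node in tree.body:  # module-level statements only; nested classes are ignored
            if isinstance(node, ast.ClassDef):
                definitions[node.name].append(path)
    return {name: paths for name, paths in definitions.items() if len(paths) > 1}


def collect_sources(root: str) -> dict[str, str]:
    """Read every .py file under `root` (e.g. "src/gym/train")."""
    return {str(p): p.read_text(encoding="utf-8") for p in Path(root).rglob("*.py")}
```

In a pytest module, `assert not find_duplicate_classes(collect_sources("src/gym/train"))` would, according to this review, currently fail. It would report `GRPOTrainer` under both `grpo/trainer.py` and `grpo/trainer_simple.py`.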


🔴 Issue #2: Massive Parameter Sprawl

Problem: A single 300+ line dataclass (RLHFArguments) holds the parameters of ALL RL methods

from dataclasses import dataclass
from typing import Optional

@dataclass
class RLHFArguments:
    # PPO params (7)
    ppo_buffer_size: int = 1
    ppo_epochs: int = 4

    # DPO params (9)
    pref_beta: float = 0.1
    dpo_label_smoothing: float = 0.0

    # KTO params (3)
    kto_chosen_weight: float = 1.0

    # GRPO params (11)
    grpo_group_size: int = 8
    grpo_beta: float = 0.1
    continuous_learning_grpo: bool = False
    experience_lib_path: Optional[str] = None
    # ... 7 more

    # GSPO params (6)
    gspo_group_size: int = 8

    # Shared (9)
    ref_model: Optional[str] = None
    # ...

Impact:
  • Hard to tell which parameters apply to which method
  • Easy to set conflicting parameters (e.g. DPO options on a GRPO run)
  • No validation that a parameter belongs to the selected method
  • The class grows linearly with each new method

Solution: Hierarchical Dataclasses

from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class BaseRLHFArguments:
    """Shared parameters for all RLHF methods."""
    ref_model: Optional[str] = None
    ref_model_adapters: Optional[str] = None
    reward_model: Optional[str] = None
    beta: float = 0.1  # Common KL penalty

@dataclass
class GRPOArguments(BaseRLHFArguments):
    """Standard GRPO parameters."""
    group_size: int = 8
    clip_range: float = 0.2
    normalize_advantages: bool = True

@dataclass  
class ContinuousLearningGRPOArguments(GRPOArguments):
    """Continuous Learning GRPO additional parameters."""
    training_free: bool = True
    experience_lib_path: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_base_url: str = "https://api.deepseek.com/v1"
    llm_model: str = "deepseek-chat"
    semantic_max_operations: int = 3
    rollout_temperature: float = 0.7
    use_groundtruth: bool = True

@dataclass
class DPOArguments(BaseRLHFArguments):
    """DPO-specific parameters."""
    ftx: float = 0.0
    label_smoothing: float = 0.0
    loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid"

Benefits:
  • ✅ Static type checkers and IDEs flag parameters that do not belong to a method
  • ✅ Clear inheritance hierarchy
  • ✅ IDE autocomplete works
  • ✅ Only relevant parameters are visible per method
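A reduced, self-contained version of the hierarchy above shows what the split buys at construction time. Inherited shared fields are accepted. A parameter that belongs to another method is rejected with a `TypeError` instead of being silently carried along. The `__post_init__` check on `group_size` is an illustrative addition, not something the plan specifies: a group of one sample has no relative signal.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class BaseRLHFArguments:
    """Shared parameters (trimmed to two fields for the example)."""
    ref_model: Optional[str] = None
    beta: float = 0.1


@dataclass
class GRPOArguments(BaseRLHFArguments):
    group_size: int = 8
    clip_range: float = 0.2

    def __post_init__(self) -> None:
        # Illustrative validation: group-relative advantages need >= 2 samples.
        if self.group_size < 2:
            raise ValueError(f"group_size must be >= 2, got {self.group_size}")


@dataclass
class DPOArguments(BaseRLHFArguments):
    label_smoothing: float = 0.0


def field_names(cls) -> list[str]:
    """List the parameters a method exposes (inherited fields come first)."""
    return [f.name for f in fields(cls)]
```

Here `GRPOArguments(beta=0.05, group_size=16)` works, while `GRPOArguments(label_smoothing=0.1)` fails immediately because `label_smoothing` only exists on `DPOArguments`.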


🔴 Issue #3: Code Duplication Across All Trainers

400+ lines are duplicated across the 8 trainers, in three recurring patterns:

Pattern 1: Reference Model Setup (repeated 4x)

# dpo/trainer.py, kto/trainer.py, grpo/trainer.py, gspo/trainer.py
if ref_model is not None:
    if self.is_deepspeed_enabled:
        if not (getattr(ref_model, "is_loaded_in_8bit", False) or 
                getattr(ref_model, "is_loaded_in_4bit", False)):
            self.ref_model = self._prepare_deepspeed(self.ref_model)
    else:
        self.ref_model = self.accelerator.prepare_model(
            self.ref_model, evaluation_mode=True
        )

Pattern 2: Optimizer/Scheduler Creation (repeated 8x)

# ALL trainers
@override
def create_optimizer(self):
    if self.optimizer is None:
        self.optimizer = create_custom_optimizer(...)
    return super().create_optimizer()

@override
def create_scheduler(self, num_training_steps, optimizer=None):
    create_custom_scheduler(...)
    return super().create_scheduler(...)

Pattern 3: Dropout Disabling (repeated 4x)

if disable_dropout:
    disable_dropout_in_model(model)
    if ref_model is not None:
        disable_dropout_in_model(ref_model)

Solution: Create BaseRLHFTrainer

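# src/gym/train/base_rlhf_trainer.py
from abc import ABC, abstractmethod
from collections import defaultdict

from transformers import Trainer
from typing_extensions import override

# create_custom_optimizer, create_custom_scheduler and SaveProcessorCallback are
# existing gym helpers; import them from wherever the current trainers do.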

class BaseRLHFTrainer(Trainer, ABC):
    """Base class for all RLHF trainers (PPO, DPO, KTO, GRPO, GSPO).

    Handles common functionality:
    - Reference model setup
    - Optimizer/scheduler creation
    - Dropout management
    - Processor callbacks
    - Metric storage
    """

    def __init__(
        self,
        model,
        ref_model,
        args,
        finetuning_args,
        processor=None,
        disable_dropout=True,
        **kwargs
    ):
        self.finetuning_args = finetuning_args
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # Disable dropout once (not per trainer)
        if disable_dropout:
            self._disable_dropout(model, ref_model)

        super().__init__(model=model, args=args, **kwargs)

        # Setup reference model once
        if ref_model is not None:
            self._setup_ref_model()

        # Add processor callback once
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

    def _setup_ref_model(self):
        """Setup reference model for DeepSpeed or standard training."""
        if self.is_deepspeed_enabled:
            if not (getattr(self.ref_model, "is_loaded_in_8bit", False) or
                    getattr(self.ref_model, "is_loaded_in_4bit", False)):
                self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare_model(
                self.ref_model, evaluation_mode=True
            )
            self.ref_model.eval()

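    def _prepare_deepspeed(self, model):
        """Wrap a frozen reference model for DeepSpeed.

        transformers.Trainer does not provide this method, so the implementation
        the current trainers rely on has to be moved here.
        """
        raise NotImplementedError

    def _disable_dropout(self, model, ref_model):
        """Disable dropout in model and reference model."""
        from trl.trainer import disable_dropout_in_model
        disable_dropout_in_model(model)
        if ref_model is not None:
            disable_dropout_in_model(ref_model)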

    @override
    def create_optimizer(self):
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(
                self.model, self.args, self.finetuning_args
            )
        return super().create_optimizer()

    @override
    def create_scheduler(self, num_training_steps, optimizer=None):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

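    @abstractmethod
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Each trainer implements its own loss computation.

        num_items_in_batch matches the keyword newer transformers releases pass in.
        """
        raise NotImplementedError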

Then each trainer becomes simple:

class GRPOTrainer(BaseRLHFTrainer):
    """Group Relative Policy Optimization Trainer."""

    def __init__(
        self,
        model,
        ref_model,
        args,
        finetuning_args: GRPOArguments,
        **kwargs,
    ):
        # GRPO-specific parameters only
        self.group_size = finetuning_args.group_size
        self.clip_range = finetuning_args.clip_range
        self.normalize_advantages = finetuning_args.normalize_advantages

        # Parent handles: ref_model, dropout, optimizer, scheduler
        super().__init__(
            model=model,
            ref_model=ref_model,
            args=args,
            finetuning_args=finetuning_args,
            **kwargs,
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """GRPO-specific loss computation."""
        # Only GRPO logic here
        ...

Lines Saved: 400+ lines across 8 trainers
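The division of labour can be checked without transformers. The toy below uses hypothetical stand-ins (`ToyBaseTrainer`, `ToyGRPOTrainer`, and plain dicts in place of models), not gym classes. Shared setup runs once in the base `__init__`, and `abc` refuses to instantiate a subclass that forgets `compute_loss`.

```python
from abc import ABC, abstractmethod
from typing import Optional


class ToyBaseTrainer(ABC):
    """Stand-in for BaseRLHFTrainer: owns the setup every method shares."""

    def __init__(self, model: dict, ref_model: Optional[dict] = None, disable_dropout: bool = True):
        self.model = model
        self.ref_model = ref_model
        self.setup_log: list[str] = []
        if disable_dropout:
            self._disable_dropout()
        if ref_model is not None:
            self.setup_log.append("ref_model prepared")  # stands in for _setup_ref_model()

    def _disable_dropout(self) -> None:
        for m in (self.model, self.ref_model):
            if m is not None:
                m["dropout"] = 0.0
        self.setup_log.append("dropout disabled")

    @abstractmethod
    def compute_loss(self, inputs: list[float]) -> float:
        """Method-specific objective; each subclass must implement it."""


class ToyGRPOTrainer(ToyBaseTrainer):
    def __init__(self, group_size: int, **kwargs):
        self.group_size = group_size  # method-specific state only
        super().__init__(**kwargs)    # shared setup lives in the base class

    def compute_loss(self, inputs: list[float]) -> float:
        return -sum(inputs) / len(inputs)  # placeholder objective, not GRPO's loss


class IncompleteTrainer(ToyBaseTrainer):
    """Forgets compute_loss, so it cannot be instantiated."""
```

`IncompleteTrainer(model={})` raises `TypeError` at construction. The real `@abstractmethod compute_loss` gives the same guarantee once `BaseRLHFTrainer` mixes in `ABC`.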


Architectural Decision: Continuous Learning GRPO

Option A (Recommended): Separate Trainer Class

import torch


class GRPOTrainer(BaseRLHFTrainer):
    """Standard GRPO with parameter updates."""

    def __init__(self, model, ref_model, args, finetuning_args: GRPOArguments, **kwargs):
        self.group_size = finetuning_args.group_size
        super().__init__(model=model, ref_model=ref_model, args=args,
                         finetuning_args=finetuning_args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # 1. Generate rollouts
        # 2. Compute numerical advantages
        # 3. Compute policy gradient loss
        # 4. Return loss; Trainer runs backward() and optimizer.step()
        return loss


class ContinuousLearningGRPOTrainer(BaseRLHFTrainer):
    """Continuous Learning GRPO with experience library updates."""

    def __init__(self, model, ref_model, args,
                 finetuning_args: ContinuousLearningGRPOArguments, **kwargs):
        self.group_size = finetuning_args.group_size

        # Training-free specific setup
        self.experience_manager = ExperienceManager(
            checkpoint_path=finetuning_args.experience_lib_path
        )

        llm_client = LLMClient(
            api_key=finetuning_args.llm_api_key,
            base_url=finetuning_args.llm_base_url,
            model=finetuning_args.llm_model
        )

        self.semantic_extractor = SemanticExtractor(
            llm_client=llm_client,
            max_operations=finetuning_args.semantic_max_operations
        )

        super().__init__(model=model, ref_model=ref_model, args=args,
                         finetuning_args=finetuning_args, **kwargs)

        # Freeze model. requires_grad=False is what prevents updates;
        # Trainer.training_step calls model.train() again on every step.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # 1. Inject experiences into prompts
        # 2. Generate rollouts with experiences
        # 3. Extract semantic advantages
        # 4. Update experience library
        # 5. Return zero loss (no gradient update!). Trainer.training_step still
        #    calls backward() on it, so it must require grad; the frozen model
        #    parameters receive no gradient.
        return torch.zeros((), device=model.device, requires_grad=True)
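The two compute_loss sketches diverge at the advantage step. For standard GRPO, a common formulation normalizes each reward against its own group: A_i = (r_i - mean(r)) / (std(r) + eps). The sketch below computes that with the standard library. It also reports groups whose rewards are all equal, because their advantages are all zero and carry no learning signal. Skipping such groups before calling the LLM critic in the continuous-learning path is an assumption of this sketch (it would save API calls), not something the plan specifies. `group_advantages` and `informative_groups` are hypothetical names.

```python
import statistics


def group_advantages(rewards: list[float], normalize: bool = True, eps: float = 1e-4) -> list[float]:
    """Group-relative advantages for one prompt's rollouts.

    A_i = r_i - mean(r), divided by (std(r) + eps) when `normalize` is True
    (mirroring the plan's normalize_advantages flag). Population std is used
    here; implementations differ on this detail.
    """
    mean = statistics.fmean(rewards)
    centered = [r - mean for r in rewards]
    if not normalize:
        return centered
    std = statistics.pstdev(rewards)
    return [c / (std + eps) for c in centered]


def informative_groups(groups: dict[str, list[float]]) -> list[str]:
    """IDs of groups with mixed outcomes; uniform groups give all-zero advantages."""
    return [gid for gid, rewards in groups.items() if len(set(rewards)) > 1]
```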

Why Option A (Separate Class) is BETTER:

Clear Separation of Concerns
  • Standard GRPO: numerical advantages → gradient updates
  • Continuous Learning GRPO: semantic advantages → experience updates
  • Two fundamentally different learning paradigms

Type Safety
  • GRPOTrainer requires GRPOArguments
  • ContinuousLearningGRPOTrainer requires ContinuousLearningGRPOArguments
  • A static type checker (mypy, pyright) catches configuration mismatches before runtime

Easier to Understand
  • New developers see two distinct classes
  • No confusing if training_free: branches
  • Clear which class to use

Better Testing
  • Test standard GRPO separately
  • Test Continuous Learning GRPO separately
  • No interaction between the two

Follows the Single Responsibility Principle
  • Each class has one reason to change
  • Standard GRPO changes don't affect the training-free trainer
  • Training-free changes don't affect standard GRPO

Easier to Extend
  • Want a training-free GSPO? Create ContinuousLearningGSPOTrainer
  • Want training-free variants of other methods? Same pattern
  • Scales better than flags

Less Cognitive Load
  • No mental overhead tracking which mode is active
  • Method signatures are simpler
  • Fewer conditional branches

Option B: Single GRPOTrainer with a training_free Flag

class GRPOTrainer(BaseRLHFTrainer):
    def __init__(self, finetuning_args, ...):
        self.training_free = finetuning_args.continuous_learning_grpo

        if self.training_free:
            # Setup experience manager, semantic extractor
            # Freeze model
            ...
        else:
            # Standard setup
            ...

    def compute_loss(self, model, inputs, return_outputs=False):
        if self.training_free:
            return self._compute_loss_training_free(...)
        else:
            return self._compute_loss_standard(...)

Why Option B is WORSE:

Poor Separation of Concerns
  • Single class doing two very different things
  • Violates the Single Responsibility Principle

Type Safety Issues
  • Static type checkers cannot require ContinuousLearningGRPOArguments when the flag is on
  • Missing parameters surface only as runtime errors

Testing Complexity
  • Need to test both modes in a single class
  • State can leak between modes
  • More edge cases to cover

Cognitive Overhead
  • Developers must track which mode is active
  • More if branches = more mental load
  • Harder to reason about behavior

Harder to Extend
  • Adding more variants compounds complexity
  • The class grows linearly with each variant


Directory Structure (Proposed)

src/gym/train/
├── base_rlhf_trainer.py           # NEW: Base class for all RLHF
├── sft/
│   └── trainer.py
├── ppo/
│   └── trainer.py
├── dpo/
│   └── trainer.py
├── kto/
│   └── trainer.py
├── grpo/
│   ├── __init__.py
│   ├── trainer.py                 # Standard GRPO (parameter updates)
│   ├── workflow.py
│   └── training_free/             # NEW: Continuous Learning GRPO
│       ├── __init__.py
│       ├── trainer.py             # ContinuousLearningGRPOTrainer
│       ├── workflow.py
│       ├── experience_manager.py
│       ├── semantic_extractor.py
│       ├── api_model_adapter.py
│       └── domains/
│           ├── __init__.py
│           ├── base.py
│           ├── math/
│           │   ├── dataset.py
│           │   ├── verify.py
│           │   └── prompts.py
│           └── web/
│               ├── dataset.py
│               ├── verify.py
│               └── prompts.py
└── gspo/
    ├── __init__.py
    └── trainer.py                 # Standard GSPO

Naming Convention:
  • GRPOTrainer: standard GRPO with parameter updates
  • ContinuousLearningGRPOTrainer: Continuous Learning GRPO with experience updates
  • GSPOTrainer: standard GSPO
  • ContinuousLearningGSPOTrainer: future extension

Import Convention:

# Standard GRPO
from gym.train.grpo import GRPOTrainer

# Continuous Learning GRPO
from gym.train.grpo.training_free import ContinuousLearningGRPOTrainer


Parameter Configuration (Proposed)

# src/gym/hparams/finetuning_args.py
from dataclasses import dataclass, field
from typing import Literal, Optional

@dataclass
class BaseRLHFArguments:
    """Shared parameters for all RLHF methods."""
    ref_model: Optional[str] = None
    ref_model_adapters: Optional[str] = None
    ref_model_quantization_bit: Optional[int] = None
    reward_model: Optional[str] = None
    reward_model_adapters: Optional[str] = None
    reward_model_quantization_bit: Optional[int] = None
    reward_model_type: Literal["lora", "full", "api"] = "lora"

@dataclass
class GRPOArguments(BaseRLHFArguments):
    """Standard GRPO parameters."""
    group_size: int = field(
        default=8,
        metadata={"help": "Group size for relative advantage computation."}
    )
    beta: float = field(
        default=0.1,
        metadata={"help": "Beta coefficient for KL penalty."}
    )
    clip_range: float = field(
        default=0.2,
        metadata={"help": "Clipping range for importance ratios."}
    )
    normalize_advantages: bool = field(
        default=True,
        metadata={"help": "Whether to normalize advantages within groups."}
    )

@dataclass
class ContinuousLearningGRPOArguments(GRPOArguments):
    """Continuous Learning GRPO parameters."""
    # Inherit group_size, beta, clip_range, normalize_advantages

    experience_lib_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save/load experience library."}
    )
    llm_api_key: Optional[str] = field(
        default=None,
        metadata={"help": "API key for LLM semantic extraction (DeepSeek/OpenAI)."}
    )
    llm_base_url: str = field(
        default="https://api.deepseek.com/v1",
        metadata={"help": "Base URL for LLM API."}
    )
    llm_model: str = field(
        default="deepseek-chat",
        metadata={"help": "LLM model name for semantic extraction."}
    )
    semantic_max_operations: int = field(
        default=3,
        metadata={"help": "Max operations per group critique."}
    )
    rollout_temperature: float = field(
        default=0.7,
        metadata={"help": "Temperature for rollout generation."}
    )
    use_groundtruth: bool = field(
        default=True,
        metadata={"help": "Use ground truth in semantic extraction."}
    )

@dataclass
class FinetuningArguments:
    """Top-level finetuning arguments."""
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "grpo", "gspo", "grpo_free"] = field(
        default="sft",
        metadata={"help": "Training stage to perform."}
    )

    # Method-specific arguments will be loaded based on stage
    # No more massive RLHFArguments with all methods mixed

Configuration Usage:

# configs/grpo_standard.yaml
stage: grpo
finetuning_type: lora

grpo:
  group_size: 8
  beta: 0.1
  clip_range: 0.2
  normalize_advantages: true

# configs/grpo_training_free.yaml
stage: grpo_free
finetuning_type: lora

grpo:
  group_size: 5
  beta: 0.1

training_free:
  experience_lib_path: ./output/experiences.json
  llm_api_key: ${DEEPSEEK_API_KEY}   # the config loader must expand this; YAML itself does not substitute env vars
  llm_base_url: https://api.deepseek.com/v1
  llm_model: deepseek-chat
  semantic_max_operations: 3
  rollout_temperature: 0.7
  use_groundtruth: true
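The YAML files above are nested by section (`grpo`, `training_free`), while the dataclasses are flat. The loader therefore has to pick the class from `stage` and merge the sections. The sketch below assumes the YAML has already been parsed into a dict (for example with PyYAML's `yaml.safe_load`). `STAGE_TO_ARGS`, `build_method_args` and the trimmed dataclasses are hypothetical. The sketch expands `${VAR}` references with `os.path.expandvars`, and unknown keys raise `TypeError` through the dataclass constructor.

```python
import os
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GRPOArguments:  # trimmed: BaseRLHFArguments fields omitted
    group_size: int = 8
    beta: float = 0.1
    clip_range: float = 0.2
    normalize_advantages: bool = True


@dataclass
class ContinuousLearningGRPOArguments(GRPOArguments):
    experience_lib_path: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_base_url: str = "https://api.deepseek.com/v1"
    llm_model: str = "deepseek-chat"
    semantic_max_operations: int = 3
    rollout_temperature: float = 0.7
    use_groundtruth: bool = True


# stage -> (dataclass, YAML sections merged into it, in order)
STAGE_TO_ARGS: dict[str, tuple[type, tuple[str, ...]]] = {
    "grpo": (GRPOArguments, ("grpo",)),
    "grpo_free": (ContinuousLearningGRPOArguments, ("grpo", "training_free")),
}


def _expand(value: Any) -> Any:
    """Expand ${VAR} / $VAR in string values; unset variables are left as-is."""
    return os.path.expandvars(value) if isinstance(value, str) else value


def build_method_args(config: dict[str, Any]):
    """Instantiate the method-specific dataclass for config["stage"]."""
    stage = config["stage"]
    if stage not in STAGE_TO_ARGS:
        raise ValueError(f"no method arguments registered for stage {stage!r}")
    cls, sections = STAGE_TO_ARGS[stage]
    merged: dict[str, Any] = {}
    for section in sections:
        merged.update({k: _expand(v) for k, v in (config.get(section) or {}).items()})
    return cls(**merged)  # unknown keys -> TypeError
```

For grpo_training_free.yaml, this would yield a ContinuousLearningGRPOArguments with `group_size=5` and the API key read from the environment. Top-level keys such as `finetuning_type` are left to the existing FinetuningArguments parsing.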

Implementation Roadmap

Phase 1: Cleanup (Week 1)

Goal: Clean foundation for Continuous Learning GRPO

Day 1-2: Create BaseRLHFTrainer

  • Create src/gym/train/base_rlhf_trainer.py
  • Extract common code (ref_model, optimizer, scheduler)
  • Add comprehensive tests
  • Update all 8 trainers to inherit from it

Day 3: Refactor Parameter Classes

  • Create hierarchical dataclasses
  • Split RLHFArguments into method-specific classes
  • Update all trainers to use new structure
  • Test configuration loading

Day 4-5: Resolve File Conflicts

  • Rename or remove trainer_simple.py files
  • Clean up quantization module organization
  • Update imports and documentation

Deliverable: Clean, DRY codebase ready for new features


Phase 2: Continuous Learning GRPO Implementation (Week 2-3)

Goal: Implement Continuous Learning GRPO with clean architecture

Day 1-2: Create ContinuousLearningGRPOTrainer

  • Create src/gym/train/grpo/training_free/trainer.py
  • Implement compute_loss() for training-free mode
  • Add experience injection logic
  • Implement rollout generation
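The experience injection step and the `ExperienceManager(checkpoint_path=...)` call in the trainer sketch imply a small persistent store. The following is one possible shape, with its assumptions stated up front:
  • experiences are short strings keyed by generated IDs such as `E0`
  • the semantic extractor emits `add`/`modify`/`delete` operations as dicts
  • the library is saved as JSON through an atomic file replace
  • injection prepends a numbered list to the prompt

None of these method names exist in gym yet.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Optional


class ExperienceManager:
    """Hypothetical experience library: ID -> natural-language experience."""

    def __init__(self, checkpoint_path: Optional[str] = None):
        self.checkpoint_path = Path(checkpoint_path) if checkpoint_path else None
        self.experiences: dict[str, str] = {}
        self._next_id = 0  # persisted so IDs are never reused after a delete
        if self.checkpoint_path is not None and self.checkpoint_path.exists():
            self.load()

    def apply(self, op: dict) -> None:
        """Apply one operation: {"op": "add" | "modify" | "delete", "id": ..., "text": ...}."""
        kind = op["op"]
        if kind == "add":
            self.experiences[f"E{self._next_id}"] = op["text"]
            self._next_id += 1
        elif kind == "modify":
            if op["id"] in self.experiences:  # edits to unknown IDs are ignored
                self.experiences[op["id"]] = op["text"]
        elif kind == "delete":
            self.experiences.pop(op["id"], None)
        else:
            raise ValueError(f"unknown operation {kind!r}")

    def inject(self, prompt: str) -> str:
        """Prepend the current library to a task prompt."""
        if not self.experiences:
            return prompt
        lines = [f"[{eid}] {text}" for eid, text in self.experiences.items()]
        return "Helpful experiences:\n" + "\n".join(lines) + "\n\n" + prompt

    def save(self) -> None:
        """Write via a temp file + os.replace so a crash cannot leave a half-written library."""
        if self.checkpoint_path is None:
            return
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {"next_id": self._next_id, "experiences": self.experiences}
        fd, tmp_path = tempfile.mkstemp(dir=self.checkpoint_path.parent, suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.checkpoint_path)

    def load(self) -> None:
        """Restore the library written by save()."""
        payload = json.loads(self.checkpoint_path.read_text(encoding="utf-8"))
        self.experiences = dict(payload["experiences"])
        self._next_id = int(payload["next_id"])
```

Storing `next_id` alongside the entries is deliberate. After a delete and a reload, new experiences never reuse an ID that an earlier critique may still reference.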

Day 3-4: Domain Modules

  • Create domains/math/ module:
      • dataset.py (AIME loader)
      • verify.py (correctness checker)
      • prompts.py (math templates)
  • Create domains/web/ module
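For the domains/math/verify.py correctness checker, AIME answers are integers from 0 to 999. The checker can therefore extract the model's final answer and compare integers. The sketch assumes the prompts ask for the answer in `\boxed{...}`, a common convention that the plan does not state. It falls back to the last integer in the text when no box is present. `extract_final_answer` and `verify` are hypothetical names.

```python
import re
from typing import Optional


def _last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...}, honouring nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces


def extract_final_answer(response: str) -> Optional[int]:
    """Parse the final integer answer from a model response."""
    boxed = _last_boxed(response)
    if boxed is not None:
        match = re.fullmatch(r"\s*(\d+)\s*", boxed)  # AIME answers are plain integers
        return int(match.group(1)) if match else None
    numbers = re.findall(r"\d+", response)  # fallback: last integer in the text
    return int(numbers[-1]) if numbers else None


def verify(response: str, ground_truth: str) -> bool:
    """True when the parsed answer equals the integer ground truth."""
    answer = extract_final_answer(response)
    return answer is not None and answer == int(ground_truth.strip())
```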

Day 5: Testing

  • Unit tests for ContinuousLearningGRPOTrainer
  • Integration tests with toy dataset
  • Test experience library persistence

Deliverable: Working Continuous Learning GRPO trainer


Phase 3: End-to-End Testing (Week 4)

Goal: Validate full workflow

Day 1-2: Example Scripts

  • Create scripts/train_grpo_free_math.py
  • Create YAML configs
  • Test with API model (DeepSeek)

Day 3-4: Benchmarking

  • Run 3-epoch training on a 100-sample math set and evaluate on AIME24 (which has only 30 problems)
  • Validate experience library growth
  • Compare with Tencent baseline
  • Cost analysis
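For the cost analysis, a back-of-the-envelope estimate can check the ≤$25 budget before a run is launched. The sketch assumes one API call per rollout plus one critique call per group per epoch; the plan does not fix this accounting. Every token count and per-million-token price is an input to fill in from the provider's current pricing, and no prices are assumed here. `CallProfile` and `estimate_cost_usd` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class CallProfile:
    """Average token usage of one API call."""
    prompt_tokens: float
    completion_tokens: float


def estimate_cost_usd(
    num_samples: int,
    epochs: int,
    group_size: int,
    rollout: CallProfile,
    critique: CallProfile,
    usd_per_m_input: float,
    usd_per_m_output: float,
) -> float:
    """Estimated API spend for a continuous-learning GRPO run.

    Assumed call counts: num_samples * epochs * group_size rollout calls and
    num_samples * epochs critique calls (one per group).
    """
    def call_cost(p: CallProfile) -> float:
        return (p.prompt_tokens * usd_per_m_input + p.completion_tokens * usd_per_m_output) / 1e6

    groups = num_samples * epochs
    return groups * group_size * call_cost(rollout) + groups * call_cost(critique)
```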

Day 5: Documentation

  • Update README with Continuous Learning GRPO
  • Add usage examples
  • Create troubleshooting guide

Deliverable: Production-ready Continuous Learning GRPO


Testing Strategy

Unit Tests

# tests/train/test_base_rlhf_trainer.py
def test_ref_model_setup():
    """Test reference model setup works for all trainers."""
    ...

def test_optimizer_creation():
    """Test custom optimizer creation."""
    ...

# tests/train/test_continuous_learning_grpo_trainer.py
def test_experience_injection():
    """Test experience library injection into prompts."""
    ...

def test_no_parameter_updates():
    """Verify model weights unchanged."""
    ...

def test_semantic_advantage_extraction():
    """Test LLM-based advantage extraction."""
    ...
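test_semantic_advantage_extraction is easiest to keep offline if the part of SemanticExtractor that interprets the LLM reply is a pure function. A hypothetical `parse_operations` could accept a JSON list of operations, drop malformed entries, and enforce `semantic_max_operations`. The reply format here is an assumption, since the plan does not define it.

```python
import json
from typing import Any

VALID_OPS = ("add", "modify", "delete")  # tuple: membership works even for unhashable values


def parse_operations(raw: str, max_operations: int = 3) -> list[dict[str, Any]]:
    """Turn an LLM reply into at most `max_operations` well-formed operations.

    The JSON list is taken from the first "[" to the last "]" so surrounding prose
    is tolerated; anything that does not parse as JSON yields [].
    """
    start, end = raw.find("["), raw.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        items = json.loads(raw[start : end + 1])
    except json.JSONDecodeError:
        return []

    ops: list[dict[str, Any]] = []
    for item in items:
        if len(ops) >= max_operations:
            break
        if not isinstance(item, dict) or item.get("op") not in VALID_OPS:
            continue  # unknown or malformed operation
        if item["op"] in ("modify", "delete") and "id" not in item:
            continue  # these need a target experience
        if item["op"] in ("add", "modify") and not str(item.get("text", "")).strip():
            continue  # these need non-empty text
        ops.append(item)
    return ops
```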

Integration Tests

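# tests/train/test_grpo_training_free_integration.py
import json
from pathlib import Path

from gym.train.grpo.training_free import ContinuousLearningGRPOTrainer


def test_full_training_workflow(tmp_path: Path):
    """Run 2 epochs on toy dataset."""
    lib_path = tmp_path / "experiences.json"
    trainer = ContinuousLearningGRPOTrainer(...)  # configured with experience_lib_path=str(lib_path)
    trainer.train()

    # Verify results
    assert lib_path.exists()
    experiences = json.loads(lib_path.read_text())
    assert len(experiences) > 0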

Success Criteria

Functional

  • BaseRLHFTrainer eliminates 400+ lines of duplication
  • All 8 trainers inherit from BaseRLHFTrainer
  • ContinuousLearningGRPOTrainer runs 3-epoch training
  • Experience library grows (50-200 experiences)
  • Zero parameter updates (model frozen)
  • Checkpoint save/load works

Quality

  • All tests passing (100% coverage for new code)
  • No code duplication
  • Clear separation of concerns
  • Type-safe parameter handling
  • Comprehensive documentation

Performance

  • AIME24 accuracy ≥75% (target: 82.7%)
  • Training cost ≤$25 (target: $18)
  • Training time ≤8 hours (target: 6 hours)

Conclusion

Recommended Approach:

  1. Refactor First (Week 1)
      • Create BaseRLHFTrainer
      • Refactor parameter classes
      • Resolve file conflicts

  2. Implement Continuous Learning GRPO (Week 2-3)
      • Create ContinuousLearningGRPOTrainer (separate class)
      • Implement domain modules
      • Comprehensive testing

  3. Validate & Document (Week 4)
      • End-to-end testing
      • Benchmarking
      • Documentation

Why This Approach:
  • ✅ Clean foundation before new features
  • ✅ Eliminates 400+ lines of duplication
  • ✅ Clear separation: GRPOTrainer vs ContinuousLearningGRPOTrainer
  • ✅ Type-safe, maintainable, extensible
  • ✅ 1 week cleanup investment → long-term maintainability

Estimated Timeline: 4 weeks total (1 week cleanup + 3 weeks implementation)


Next Step: Approve cleanup plan and proceed with Phase 1 (BaseRLHFTrainer creation)?