CODEBASE
← 返回列表

环境封装工具

Gym Environment Wrapper | 通用环境支持

概述

通用环境封装工具,提供 OpenAI Gym 环境封装、多智能体环境封装和并行环境采样等功能,为强化学习算法提供统一的环境接口。

核心实现

import copy
import multiprocessing as mp
from collections import deque
from concurrent.futures import ThreadPoolExecutor

import gym
import numpy as np

class GymWrapper:
    """Uniform wrapper around a Gym-style environment.

    ``reset``/``step``/``render`` are forwarded to the primary environment
    ``self.env``.  In addition the wrapper keeps ``n_workers`` independent
    environments in ``self.envs`` that can be stepped together through
    :meth:`parallel_step`.

    The code expects the classic step contract in which ``env.step(action)``
    returns the 4-tuple ``(obs, reward, done, info)``.

    Args:
        env: Environment object providing ``action_space``,
            ``observation_space``, ``reset``, ``step``, ``render``,
            ``close``, ``spec`` and ``unwrapped``.
        n_workers: Number of environments kept in ``self.envs``.
    """

    def __init__(self, env, n_workers=4):
        self.env = env
        self.n_workers = n_workers
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        # Worker environments.  Slot 0 is the primary env itself; every other
        # slot is a deep copy so that each worker owns its own state.  Storing
        # the same object n_workers times would make "parallel" steps mutate
        # one shared environment.
        self.envs = [
            env if index == 0 else copy.deepcopy(env)
            for index in range(n_workers)
        ]

    def reset(self):
        """Reset the primary environment and return whatever it returns."""
        return self.env.reset()

    def step(self, action):
        """Step the primary environment.

        Returns:
            The tuple ``(obs, reward, done, info)`` produced by the
            environment.
        """
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

    def parallel_step(self, actions):
        """Step the worker environments, one action per environment.

        ``actions[i]`` is applied to ``self.envs[i]``; pairing stops at the
        shorter of the two sequences.

        Threads are used instead of a process pool: a process pool has to
        pickle the callable and the environments (a lambda cannot be
        pickled), and the step would run on a copy in the child process, so
        the state of ``self.envs`` would never advance.

        Returns:
            A list with the result of ``env.step(action)`` for each pair, in
            environment order.
        """
        # ThreadPoolExecutor rejects max_workers < 1, hence the lower bound.
        pool_size = max(1, self.n_workers)
        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            return list(
                pool.map(lambda env, action: env.step(action),
                         self.envs, actions)
            )

    def render(self):
        """Render the primary environment."""
        return self.env.render()

    def close(self):
        """Close the primary environment and every worker, each exactly once."""
        closed_ids = set()
        for env in [self.env, *self.envs]:
            # The primary env also lives in self.envs; skip duplicates so no
            # environment is closed twice.
            if id(env) in closed_ids:
                continue
            closed_ids.add(id(env))
            env.close()

    @property
    def spec(self):
        """The ``spec`` attribute of the primary environment."""
        return self.env.spec

    @property
    def unwrapped(self):
        """The ``unwrapped`` attribute of the primary environment."""
        return self.env.unwrapped

class MultiAgentEnv:
    """Multi-agent environment wrapper.

    Creates ``n_agents`` :class:`GymWrapper` instances around ``env`` and
    fans ``reset``/``step`` out to them in agent order.

    NOTE(review): every agent is constructed from the same ``env`` object,
    so the agents presumably act on shared environment state -- confirm
    this is the intended multi-agent model.

    Args:
        env: Environment handed to each agent's ``GymWrapper``.
        n_agents: Number of agent wrappers to create.
    """

    def __init__(self, env, n_agents=2):
        self.env = env
        self.n_agents = n_agents
        # Agents are only driven through reset()/step() here, never through
        # parallel_step(), so a single environment slot per agent is enough.
        self.agents = [GymWrapper(env, n_workers=1) for _ in range(n_agents)]

    def reset(self):
        """Reset every agent and return their results as a list."""
        return [agent.reset() for agent in self.agents]

    def step(self, actions):
        """Step each agent with its own action.

        Args:
            actions: Indexable container; ``actions[i]`` goes to agent ``i``.

        Returns:
            A list of ``(obs, reward, done, info)`` tuples, one per agent,
            in agent order.
        """
        transitions = []
        for index, agent in enumerate(self.agents):
            obs, reward, done, info = agent.step(actions[index])
            transitions.append((obs, reward, done, info))
        return transitions

    def render(self):
        """Render the wrapped environment."""
        return self.env.render()

    def close(self):
        """Close the wrapped environment.

        Mirrors ``GymWrapper.close`` so callers can release the environment
        through this wrapper as well.
        """
        self.env.close()

class ParallelEnv:
    """Sampling helper that draws transitions from a pool of environments.

    Args:
        envs: Sequence of environments.  Each must provide ``reset()``,
            ``step(action)`` returning ``(obs, reward, done, info)`` and an
            ``action_space`` with a ``sample()`` method.
    """

    def __init__(self, envs):
        self.envs = envs
        self.n_envs = len(envs)

    def sample_batch(self, batch_size):
        """Collect ``batch_size`` random-action transitions.

        For every sample one environment is picked uniformly at random and
        stepped once with an action drawn from its action space.  An
        environment is reset the first time it is used in a call and again
        after it reports ``done``; otherwise its episode continues.
        Resetting before every single step would only ever yield first-step
        transitions and would practically never observe ``done``.

        Args:
            batch_size: Number of transitions to collect.

        Returns:
            Tuple ``(obs, actions, rewards, dones)`` of NumPy arrays, each
            with ``batch_size`` entries; ``obs`` holds the observation the
            action was taken from.

        Raises:
            ValueError: If transitions are requested but there are no
                environments to sample from.
        """
        if batch_size > 0 and self.n_envs == 0:
            raise ValueError("cannot sample from an empty environment list")

        obs_batch = []
        action_batch = []
        reward_batch = []
        done_batch = []

        # Current observation of each running episode.  Keyed by object
        # identity rather than list index so that the same environment
        # appearing several times in self.envs is tracked as one episode.
        current_obs = {}

        for _ in range(batch_size):
            env = self.envs[np.random.randint(0, self.n_envs)]
            key = id(env)
            if key not in current_obs:
                current_obs[key] = env.reset()
            obs = current_obs[key]

            action = env.action_space.sample()
            obs_next, reward, done, _ = env.step(action)

            obs_batch.append(obs)
            action_batch.append(action)
            reward_batch.append(reward)
            done_batch.append(done)

            if done:
                # Episode finished: force a reset on the next use.
                del current_obs[key]
            else:
                current_obs[key] = obs_next

        return (np.array(obs_batch), np.array(action_batch),
                np.array(reward_batch), np.array(done_batch))

# Usage example
if __name__ == "__main__":
    # Create the base environment.
    env = gym.make("CartPole-v1")

    # Wrap it.
    wrapped = GymWrapper(env, n_workers=4)

    try:
        # Sample a batch.  sample_batch() is defined on ParallelEnv, not on
        # GymWrapper, so sampling goes through a ParallelEnv built over the
        # wrapper's environments.
        sampler = ParallelEnv(wrapped.envs)
        obs, actions, rewards, dones = sampler.sample_batch(32)

        print(f"采样形状: obs={obs.shape}, actions={actions.shape}")
        print(f"奖励范围: [{rewards.min():.2f}, {rewards.max():.2f}]")
        print(f"完成率: {dones.mean():.2f}")
    finally:
        # Release the environment(s) even if sampling fails.
        wrapped.close()
    
← 返回列表