环境封装工具
概述
通用环境封装类,支持OpenAI Gym、多智能体环境、并行环境采样等功能,为强化学习算法提供统一的环境接口。
核心实现
import numpy as np
import gym
from collections import deque
import multiprocessing as mp
class GymWrapper:
    """Unified wrapper around a Gym-style environment.

    Exposes the standard single-env API (reset/step/render/close) plus a
    batched ``parallel_step`` that advances ``n_workers`` env instances.

    Parameters
    ----------
    env : gym.Env-like
        Environment exposing ``action_space``, ``observation_space``,
        ``reset``, ``step``, ``render`` and ``close``.
    n_workers : int
        Number of env instances used by ``parallel_step``.
    """

    def __init__(self, env, n_workers=4):
        import copy

        self.env = env
        self.n_workers = n_workers
        self.action_space = env.action_space
        self.observation_space = env.observation_space
        # BUG FIX: the original stored the *same* env object n_workers
        # times, so every "parallel" step mutated one shared environment.
        # Give each slot an independent deep copy; fall back to sharing
        # when the env is not copyable (many gym envs hold OS handles).
        self.envs = []
        for _ in range(n_workers):
            try:
                self.envs.append(copy.deepcopy(env))
            except Exception:
                self.envs.append(env)

    def reset(self):
        """Reset the primary environment; returns the initial observation."""
        return self.env.reset()

    def step(self, action):
        """Advance the primary environment one step.

        Returns the standard ``(obs, reward, done, info)`` tuple.
        """
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

    def parallel_step(self, actions):
        """Step each worker env with its corresponding action.

        BUG FIX: the original passed a lambda to ``mp.Pool.starmap``,
        which always raises PicklingError (lambdas are not picklable),
        and gym envs themselves are generally unpicklable as well.
        Steps now run in-process, which is correct and avoids IPC
        overhead for cheap envs.

        Returns a list of ``(obs, reward, done, info)`` tuples, one per
        worker env (truncated to ``min(len(envs), len(actions))``).
        """
        return [env.step(a) for env, a in zip(self.envs, actions)]

    def render(self):
        return self.env.render()

    def close(self):
        """Close every distinct underlying env exactly once.

        De-duplicates by identity: when deepcopy fell back to sharing,
        all slots alias one env and double-closing can raise.
        """
        seen = set()
        for env in [self.env, *self.envs]:
            if id(env) not in seen:
                seen.add(id(env))
                env.close()

    @property
    def spec(self):
        return self.env.spec

    @property
    def unwrapped(self):
        return self.env.unwrapped
class MultiAgentEnv:
    """Multi-agent environment wrapper: one GymWrapper per agent.

    Parameters
    ----------
    env : env or sequence of envs
        A single environment shared by all agents (original behaviour),
        or a sequence of per-agent environments.  The sequence form is a
        backward-compatible generalization: it gives each agent truly
        independent state instead of every agent mutating one env.
    n_agents : int
        Number of agents; ignored when a sequence of envs is supplied.
    """

    def __init__(self, env, n_agents=2):
        if isinstance(env, (list, tuple)):
            # One independent env per agent.
            self.env = env[0]
            self.n_agents = len(env)
            self.agents = [GymWrapper(e) for e in env]
        else:
            # NOTE(review): in this form all agents wrap the SAME env
            # object (original behaviour) — reset/step hit it once per
            # agent.  Pass a list of envs for independent agents.
            self.env = env
            self.n_agents = n_agents
            self.agents = [GymWrapper(env) for _ in range(n_agents)]

    def reset(self):
        """Reset every agent's env; returns a list of initial observations."""
        return [agent.reset() for agent in self.agents]

    def step(self, actions):
        """Step agent ``i`` with ``actions[i]``.

        Returns a list of ``(obs, reward, done, info)`` tuples, one per
        agent.  Raises IndexError if fewer actions than agents are given
        (same as the original).
        """
        results = []
        for i, agent in enumerate(self.agents):
            obs, reward, done, info = agent.step(actions[i])
            results.append((obs, reward, done, info))
        return results

    def render(self):
        return self.env.render()
class ParallelEnv:
    """Random-transition sampler over a pool of environments."""

    def __init__(self, envs):
        self.envs = envs
        self.n_envs = len(envs)

    def sample_batch(self, batch_size):
        """Collect ``batch_size`` independent single-step transitions.

        Each sample draws an env uniformly at random, resets it, takes
        one random action, and records the pre-step observation, the
        action, the reward and the done flag.

        Returns four stacked numpy arrays:
        ``(obs, actions, rewards, dones)``.
        """
        obs_list, act_list, rew_list, done_list = [], [], [], []
        for _ in range(batch_size):
            chosen = self.envs[np.random.randint(0, self.n_envs)]
            first_obs = chosen.reset()
            act = chosen.action_space.sample()
            _, rew, finished, _ = chosen.step(act)
            obs_list.append(first_obs)
            act_list.append(act)
            rew_list.append(rew)
            done_list.append(finished)
        return (np.array(obs_list), np.array(act_list),
                np.array(rew_list), np.array(done_list))
# Usage example
if __name__ == "__main__":
    import gym

    # Create and wrap the base environment.
    env = gym.make("CartPole-v1")
    wrapped = GymWrapper(env, n_workers=4)

    # BUG FIX: `sample_batch` is defined on ParallelEnv, not GymWrapper —
    # the original called wrapped.sample_batch(32) and raised
    # AttributeError.  Build a sampler over the wrapper's env pool.
    sampler = ParallelEnv(wrapped.envs)
    obs, actions, rewards, dones = sampler.sample_batch(32)
    print(f"采样形状: obs={obs.shape}, actions={actions.shape}")
    print(f"奖励范围: [{rewards.min():.2f}, {rewards.max():.2f}]")
    print(f"完成率: {dones.mean():.2f}")