Creating environments with gym: a custom gym environment

Posted by Wei_Xiong on 2024-08-15

Environment: half_cheetah.py. The class below subclasses Gymnasium's MujocoEnv and returns a 2-dimensional vector reward (running speed and energy efficiency) instead of a scalar reward.

from os import path

import numpy as np

from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box

DEFAULT_CAMERA_CONFIG = {
    "distance": 4.0,
}


class MOHalfCheetahEnv(MujocoEnv, utils.EzPickle):
    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 20,
    }

    def __init__(
            self,
            **kwargs,
    ):
        utils.EzPickle.__init__(
            self,
            **kwargs,
        )

        # define the observation space: 17 dimensions (qpos without root x, plus qvel)
        observation_space = Box(
            low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
        )

        # initialize the MuJoCo base environment
        MujocoEnv.__init__(
            self,
            "half_cheetah.xml", # 直接使用庫裡面的
            5,
            observation_space=observation_space,
            default_camera_config=DEFAULT_CAMERA_CONFIG,
            **kwargs,
        )

        # multi-objective (MO) attributes: a 2-dimensional reward space
        self.reward_space = Box(low=-np.inf, high=np.inf, shape=(2,))
        self.reward_dim = 2

    def step(self, action):
        # PGMORL / PD-MORL clip the action here before stepping the simulation
        action = np.clip(action, -1.0, 1.0)

        # compute forward velocity from the change in x-position
        x_position_before = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.data.qpos[0]
        x_velocity = (x_position_after - x_position_before) / self.dt

        # observation
        observation = self._get_obs()

        # vector reward: objective 1 = forward speed (capped at 4.0), objective 2 = energy efficiency
        alive_bonus = 1
        reward_run = min(4.0, x_velocity) + alive_bonus
        reward_energy = 4.0 - 1.0 * np.square(action).sum() + alive_bonus
        vec_reward = np.array([reward_run, reward_energy], dtype=np.float32)

        # terminated truncated
        ang = self.data.qpos[2]
        # terminated = not (abs(ang) < np.deg2rad(50))  # 終止 pgmorl pdmorl有終止
        terminated = False  # 終止 pgmorl pdmorl有終止
        truncated = False  # 截斷

        # info
        info = {}

        # render
        if self.render_mode == "human":
            self.render()

        return observation, vec_reward, terminated, truncated, info

    def _get_obs(self):
        position = self.data.qpos.flat.copy()
        velocity = self.data.qvel.flat.copy()

        position = position[1:]  # drop root x-position so the observation has 17 dimensions

        observation = np.concatenate((position, velocity)).ravel()
        return observation

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.1, high=0.1, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()
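
A quick sanity check of the class (a minimal sketch; it assumes half_cheetah.py is importable from the current directory and that MuJoCo is installed):

import numpy as np
from half_cheetah import MOHalfCheetahEnv

# instantiate the raw environment (no TimeLimit wrapper yet)
env = MOHalfCheetahEnv()
obs, info = env.reset(seed=0)
assert obs.shape == (17,)  # 8 positions (root x dropped) + 9 velocities

obs, vec_reward, terminated, truncated, info = env.step(env.action_space.sample())
assert vec_reward.shape == (2,)  # [reward_run, reward_energy]
print(vec_reward)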

Registering the environment without running the env checker. Because step() returns a vector reward rather than a scalar, Gymnasium's default env checker would complain, so disable_env_checker=True is passed to gym.make.

from gymnasium.envs.registration import register
import mo_gymnasium as mo_gym
from half_cheetah import MOHalfCheetahEnv

register(
    id="wx-half-v1",
    entry_point=MOHalfCheetahEnv,
    max_episode_steps=500,
)

if __name__ == '__main__':
    import gymnasium as gym

    # env = MOHalfCheetahEnv(render_mode="human")
    # env = MOHalfCheetahEnv()
    # env = mo_gym.make('mo-halfcheetah-v4')  # no termination; TimeLimit truncates at 1000 steps
    # env = gym.make("HalfCheetah-v4")  # no termination; TimeLimit truncates at 1000 steps
    env = gym.make("wx-half-v1", disable_env_checker=True)

    done = False
    obv, info = env.reset(seed=5)
    env.action_space.seed(5)
    env.observation_space.seed(5)

    print(type(env))

    steps = 0
    while not done:
        action = env.action_space.sample()
        obv, r, d1, d2, _ = env.step(action)
        # print(r)
        done = d1 or d2
        steps += 1
        print(steps)

    print(steps)
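
Since step() returns a 2-dimensional reward, a multi-objective algorithm (or a fixed preference) eventually has to scalarize it. Below is a minimal sketch of linear scalarization with a hypothetical weight vector, using plain NumPy; it assumes the register() call above has already run in the same process:

import numpy as np
import gymnasium as gym

weights = np.array([0.7, 0.3])  # hypothetical preference over [reward_run, reward_energy]

env = gym.make("wx-half-v1", disable_env_checker=True)
obs, info = env.reset(seed=5)
done = False
scalar_return = 0.0
while not done:
    obs, vec_r, terminated, truncated, _ = env.step(env.action_space.sample())
    scalar_return += float(np.dot(weights, vec_r))  # linear scalarization of the vector reward
    done = terminated or truncated
print(scalar_return)

mo_gymnasium also ships a LinearReward wrapper that performs the same scalarization; its exact import path differs between versions.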
