starlight0798 / drl-ts

Deep reinforcement learning (DRL) experiments built on the Tianshou framework, covering environments such as Gym, PettingZoo, and Atari

License: MIT License


drl-ts's Introduction

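Introduction

This is a deep reinforcement learning (DRL) experiment project built on the Tianshou framework. It targets environments such as Gymnasium, PettingZoo, and Atari, and is intended for personal study and research.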

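Installation

Python version requirement

Use Python 3.11. Do not use 3.10 or 3.12.

Installing Anaconda is recommended. Create and activate an environment with the following commands: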

conda create -n drl python=3.11
conda activate drl
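
Because the project is pinned to Python 3.11, a quick interpreter check can catch a wrong environment before anything is installed. The following is a minimal sketch; the helper name `is_supported_python` is hypothetical and not part of this project.

```python
import sys

REQUIRED = (3, 11)  # the version this README asks for


def is_supported_python(version_info=sys.version_info) -> bool:
    """Return True only when the interpreter is Python 3.11.x (hypothetical helper)."""
    return tuple(version_info[:2]) == REQUIRED


if __name__ == "__main__":
    if is_supported_python():
        print("Python 3.11 detected")
    else:
        found = f"{sys.version_info.major}.{sys.version_info.minor}"
        print(f"Expected Python 3.11, found {found}")
```

Run it inside the activated `drl` environment, so that it checks the interpreter conda selected and not the system one.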

Install Tianshou and dependencies

  1. Clone the Tianshou repository and install it:

    git clone https://github.com/thu-ml/tianshou.git
    cd tianshou
    conda activate drl  
    pip install .
  2. Install the base dependencies:

    pip install -r requirements-base.txt
  3. Install the remaining dependencies:

    pip install -r requirements.txt
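
After the three steps above, it may help to confirm that the key packages can be found from the active environment. The sketch below uses only the standard library. The helper name `missing_packages` is hypothetical, and the package list (`tianshou`, `torch`, `gymnasium`, `pettingzoo`) is an assumption based on the libraries this README mentions, not the actual contents of the requirements files.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the top-level package names that cannot be found (hypothetical helper)."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    # Assumed package list; adjust it to match requirements-base.txt and requirements.txt.
    missing = missing_packages(["tianshou", "torch", "gymnasium", "pettingzoo"])
    print("missing:", missing or "none")
```

`find_spec` only locates a package without importing it, so the check stays fast and has no side effects.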

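Usage

This project provides example code to help you get started quickly with DRL experiments in the Tianshou framework.

Because Tianshou is already fairly complete in its algorithms and training methods, I currently focus on developing different neural networks and on comparing the training efficiency and performance of different algorithms within the Tianshou framework.

Readers can refer to /utils/model.py and try the following neural networks for feature extraction: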

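import torch
import torch.nn as nn

# Note: MLP is a helper that is not shown in this excerpt;
# it is presumably defined alongside these classes in /utils/model.py.

# MLP Concat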
class PSCN(nn.Module):
    def __init__(self, input_dim, output_dim, linear=nn.Linear):
        super(PSCN, self).__init__()
        assert output_dim >= 32 and output_dim % 8 == 0, "output_dim must be >= 32 and divisible by 8"
        self.hidden_dim = output_dim
        self.fc1 = MLP([input_dim, self.hidden_dim], last_act=True, linear=linear)
        self.fc2 = MLP([self.hidden_dim // 2, self.hidden_dim // 2], last_act=True, linear=linear)
        self.fc3 = MLP([self.hidden_dim // 4, self.hidden_dim // 4], last_act=True, linear=linear)
        self.fc4 = MLP([self.hidden_dim // 8, self.hidden_dim // 8], last_act=True, linear=linear)

    def forward(self, x):
        _shape = x.shape
        if len(_shape) > 2:
            x = x.view(-1, _shape[-1])
        
        x = self.fc1(x)

        x1 = x[:, :self.hidden_dim // 2]
        x = x[:, self.hidden_dim // 2:]
        x = self.fc2(x)

        x2 = x[:, :self.hidden_dim // 4]
        x = x[:, self.hidden_dim // 4:]
        x = self.fc3(x)

        x3 = x[:, :self.hidden_dim // 8]
        x = x[:, self.hidden_dim // 8:]
        x4 = self.fc4(x)

        out = torch.cat([x1, x2, x3, x4], dim=1)
        
        if len(_shape) > 2:
            out = out.view(_shape[0], _shape[1], -1)
        return out


# Dense layer (single layer)
class DenseLayer(nn.Module):
    def __init__(self, in_features, growth_rate):
        super(DenseLayer, self).__init__()
        self.fc = MLP([in_features, growth_rate], last_act=True)

    def forward(self, x):
        return torch.cat([x, self.fc(x)], dim=-1)


# Dense block (a stack of dense layers)
class DenseBlock(nn.Module):
    def __init__(self, in_features, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        layers = []
        for i in range(num_layers):
            layers.append(DenseLayer(in_features + i * growth_rate, growth_rate))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
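
The width bookkeeping in the two designs above can be checked without PyTorch. The sketch below is plain Python arithmetic that mirrors the slicing in `PSCN.forward` and the concatenation in `DenseLayer.forward`. The function names `pscn_branch_widths` and `dense_block_out_features` are hypothetical helpers, not part of this project.

```python
def pscn_branch_widths(output_dim: int) -> list[int]:
    """Widths of x1, x2, x3, x4, the four branches that PSCN concatenates."""
    if output_dim < 32 or output_dim % 8 != 0:
        raise ValueError("output_dim must be >= 32 and divisible by 8")
    half, quarter, eighth = output_dim // 2, output_dim // 4, output_dim // 8
    # Each of the first three stages keeps the first half of its output as a
    # branch and passes the second half on; fc4 outputs the remaining eighth.
    return [half, quarter, eighth, eighth]


def dense_block_out_features(in_features: int, growth_rate: int, num_layers: int) -> int:
    """Each DenseLayer appends growth_rate features to its input."""
    return in_features + num_layers * growth_rate


print(pscn_branch_widths(64))  # [32, 16, 8, 8]
print(dense_block_out_features(16, 8, 3))  # 40
```

The branch widths always sum to `output_dim`. Because the width is halved three times, `output_dim` has to be divisible by 8 for the splits to come out even.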

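Contributing

Issues and pull requests that improve this project are welcome. Please read and follow the contribution guidelines before submitting.

License

This project is released under the MIT License. See the LICENSE file for more information.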
