
Connecting C++ Reinforcement Learning to OpenAI Gym via Python Bindings


OpenAI Gym is the most popular experimentation environment for reinforcement learning; to some extent its interface has become a de facto standard. On one hand, many algorithm implementations are written against the gym API; on the other, new scenarios are routinely wrapped in the gym interface. This layer of abstraction decouples algorithms from environments, so the two can be combined freely. But gym is a Python interface: an RL algorithm implemented in C++ cannot talk to it directly. One option is a cross-process design: one process runs the Python environment, another runs the RL algorithm, and every interaction is serialized, sent over IPC, and deserialized on the other side. The other is a single-process design: gym and the RL algorithm run in the same process, connected through a Python binding. This article takes the latter route, using pybind11 as the bridge so that gym and a C++ RL algorithm communicate within a single process.

For the C++ machine learning framework we use LibTorch, PyTorch's C++ distribution: among today's mainstream training frameworks it has comparatively good C++ support and is straightforward to install (see the official guide "Installing C++ Distributions of PyTorch"). The official PyTorch examples include reinforce.py, an implementation of REINFORCE, an old but classic RL algorithm, and we take it as our starting point.
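Before wiring anything up, it is worth a quick sanity check that the LibTorch install works. A minimal sketch like the following (assuming libtorch is visible to your compiler and linker) should build and print a random tensor:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    // If LibTorch links and runs correctly, this prints a 2x3 tensor
    // of uniform random values.
    torch::Tensor t = torch::rand({2, 3});
    std::cout << t << std::endl;
    return 0;
}
```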

First, the Python side, lightly adapted from the original sample. The calls into the RL algorithm are abstracted behind RLWrapper, the class that will later be bound to C++. Its constructor takes descriptions of the gym environment's state and action spaces; reset() tells the agent the environment has been reset and passes in the initial state; act() returns an action for the current state according to the policy; and update() learns the parameters of the policy function.

            
              
```python
...

def state_space_desc(space):
    # Describe a gym observation space as a plain dict of built-in types
    # that can cross the pybind11 boundary.
    if isinstance(space, gym.spaces.Box):
        assert type(space.shape) == tuple
        return dict(stype='Box', dtype=str(space.dtype), shape=space.shape)
    else:
        raise NotImplementedError('unknown state space {}'.format(space))


def action_space_desc(space):
    if isinstance(space, gym.spaces.Discrete):
        return dict(stype='Discrete', dtype=str(space.dtype), shape=(space.n,))
    else:
        raise NotImplementedError('unknown action space {}'.format(space))


def main(args):
    env = gym.make(args.env)
    env.seed(args.seed)

    agent = nativerl.RLWrapper(state_space_desc(env.observation_space),
                               action_space_desc(env.action_space))

    running_reward = 10
    for i in range(args.epoch):
        obs = env.reset()
        ep_reward = 0
        agent.reset(obs)
        for t in range(1, args.step):
            if args.render:
                env.render()
            action = agent.act(obs)
            obs, reward, done, info = env.step(action)
            agent.update(reward, done)
            ep_reward += reward
            if done:
                break

        # Exponential moving average of the episode reward.
        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        agent.episode_finish()
        if i % args.log_itv == 0:
            print("Episode {}\t Last reward: {:.2f}\t step: {}\t Average reward: {:.2f}"
                  .format(i, ep_reward, t, running_reward))
        if env.spec.reward_threshold and running_reward > env.spec.reward_threshold:
            print("Solved. Running reward: {}, Last reward: {}".format(running_reward, t))
            break

    env.close()
```

Next comes the Python binding of RLWrapper. Its main job is converting Python objects into C++ data structures.
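Concretely, for CartPole-v1 the two descriptor functions above should come out roughly as dict(stype='Box', dtype='float32', shape=(4,)) for observations and dict(stype='Discrete', dtype='int64', shape=(2,)) for actions (exact dtype strings may vary with the gym version); these are precisely the keys the binding below unpacks.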

            
              
```cpp
...

namespace py = pybind11;

class RLWrapper {
public:
    RLWrapper(const py::dict& state_space, const py::dict& action_space) {
        spdlog::set_level(spdlog::level::info);
        torch::manual_seed(nrl::kSeed);

        // Convert the Python-side space descriptions into C++ structs.
        nrl::SpaceDesc ss;
        nrl::SpaceDesc as;
        ss.stype = py::cast<std::string>(state_space["stype"]);
        as.stype = py::cast<std::string>(action_space["stype"]);
        ss.dtype = py::cast<std::string>(state_space["dtype"]);
        as.dtype = py::cast<std::string>(action_space["dtype"]);

        py::tuple shape;
        shape = py::cast<py::tuple>(state_space["shape"]);
        for (const auto& item : shape) {
            ss.shape.push_back(py::cast<int64_t>(item));
        }
        shape = py::cast<py::tuple>(action_space["shape"]);
        for (const auto& item : shape) {
            as.shape.push_back(py::cast<int64_t>(item));
        }

        mStateSpaceDesc = ss;
        mActionSpaceDesc = as;
        mAgent = std::make_shared<nrl::Reinforce>(ss, as);
    }

    void reset(py::array_t<float, py::array::c_style | py::array::forcecast> state) {
        // request() exposes the numpy buffer without copying.
        py::buffer_info buf = state.request();
        float* pbuf = static_cast<float*>(buf.ptr);
        assert(buf.shape == mStateSpaceDesc.shape);
        mAgent->reset(nrl::Blob{pbuf, mStateSpaceDesc.shape});
    }

    py::object act(py::array_t<float, py::array::c_style | py::array::forcecast> state) {
        py::buffer_info buf = state.request();
        float* pbuf = static_cast<float*>(buf.ptr);
        assert(buf.shape == mStateSpaceDesc.shape);
        torch::Tensor action =
            mAgent->act(nrl::Blob{pbuf, mStateSpaceDesc.shape}).contiguous().cpu();
        return py::int_(action.item<long>());
    }

    void update(float reward, bool done) {
        mAgent->update(reward, done);
    }

    void episode_finish() {
        spdlog::trace("{}", __func__);
        mAgent->onEpisodeFinished();
    }

    ~RLWrapper() {}

private:
    nrl::SpaceDesc mStateSpaceDesc;
    nrl::SpaceDesc mActionSpaceDesc;
    std::shared_ptr<nrl::RLBase> mAgent;
};

PYBIND11_MODULE(nativerl, m) {
    py::class_<RLWrapper>(m, "RLWrapper")
        .def(py::init<const py::dict &, const py::dict &>())
        .def("reset", &RLWrapper::reset)
        .def("episode_finish", &RLWrapper::episode_finish)
        .def("act", &RLWrapper::act)
        .def("update", &RLWrapper::update);
}
```
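The nrl::SpaceDesc and nrl::Blob types are not shown here, but from how the wrapper fills and consumes them they are presumably simple aggregates along these lines (a sketch, not the actual headers):

```cpp
#include <cstdint>
#include <string>
#include <vector>

namespace nrl {

// Shape/type description of a gym space, mirroring the Python-side dict.
struct SpaceDesc {
    std::string stype;            // "Box", "Discrete", ...
    std::string dtype;            // e.g. "float32"
    std::vector<int64_t> shape;   // e.g. {4} for CartPole observations
};

// Non-owning view of a flat float buffer plus its logical shape,
// suitable for handing to torch::from_blob().
struct Blob {
    float* pbuf;
    std::vector<int64_t> shape;
};

}  // namespace nrl
```

Note also that PYBIND11_MODULE(nativerl, m) fixes the module name: the compiled shared library must be importable as nativerl on the Python side.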
              
            
          

This is essentially the glue layer between Python and C++. The real work lives in RLBase, an abstract class that defines the basic reinforcement learning interfaces. We implement the REINFORCE algorithm in its subclass Reinforce.
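Judging from the calls the wrapper makes, RLBase presumably looks something like this (again a sketch, as the actual class is not shown):

```cpp
namespace nrl {

// Abstract agent interface driven by the pybind11 glue layer.
class RLBase {
public:
    virtual ~RLBase() = default;
    virtual void reset(const Blob& state) = 0;         // start of a new episode
    virtual torch::Tensor act(const Blob& state) = 0;  // pick an action for a state
    virtual void update(float reward, bool done) = 0;  // feed back one step's result
    virtual void onEpisodeFinished() = 0;              // episode-boundary hook
};

}  // namespace nrl
```

The train() method in Reinforce is the heart of the algorithm: it accumulates discounted returns back to front and descends the policy-gradient loss,

$$G_t = r_t + \gamma\, G_{t+1}, \qquad L(\theta) = -\sum_t \log \pi_\theta(a_t \mid s_t)\, \hat{G}_t,$$

where $\hat{G}_t$ are the returns after mean/std normalization; minimizing $L$ raises the log-probability of actions that led to above-average returns. Here is Reinforce: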

            
              
                .
              
              
                .
              
              
                .
              
              
                class
              
              
                Reinforce
              
              
                :
              
              
                public
              
               RLBase 
              
                {
              
              
                public
              
              
                :
              
              
                Reinforce
              
              
                (
              
              
                const
              
               SpaceDesc
              
                &
              
               ss
              
                ,
              
              
                const
              
               SpaceDesc
              
                &
              
               as
              
                )
              
              
                :
              
              
                mPolicy
              
              
                (
              
              std
              
                ::
              
              make_shared
              
                <
              
              Policy
              
                >
              
              
                (
              
              ss
              
                ,
              
               as
              
                ,
              
               mDevice
              
                )
              
              
                )
              
              
                {
              
              
        mPolicy
              
                -
              
              
                >
              
              
                to
              
              
                (
              
              mDevice
              
                )
              
              
                ;
              
              

        mRewards 
              
                =
              
               torch
              
                ::
              
              
                zeros
              
              
                (
              
              
                {
              
              mCapacity
              
                }
              
              
                ,
              
               torch
              
                ::
              
              
                TensorOptions
              
              
                (
              
              mDevice
              
                )
              
              
                )
              
              
                ;
              
              
        mReturns 
              
                =
              
               torch
              
                ::
              
              
                zeros
              
              
                (
              
              
                {
              
              mCapacity
              
                }
              
              
                ,
              
               torch
              
                ::
              
              
                TensorOptions
              
              
                (
              
              mDevice
              
                )
              
              
                )
              
              
                ;
              
              

        mOptimizer 
              
                =
              
               std
              
                ::
              
              make_shared
              
                <
              
              torch
              
                ::
              
              optim
              
                ::
              
              Adam
              
                >
              
              
                (
              
              mPolicy
              
                -
              
              
                >
              
              
                parameters
              
              
                (
              
              
                )
              
              
                ,
              
               
                torch
              
                ::
              
              optim
              
                ::
              
              
                AdamOptions
              
              
                (
              
              mInitLR
              
                )
              
              
                )
              
              
                ;
              
              
                }
              
              
                virtual
              
               torch
              
                ::
              
              Tensor 
              
                act
              
              
                (
              
              
                const
              
               Blob
              
                &
              
               s
              
                )
              
               override 
              
                {
              
              
                auto
              
               state 
              
                =
              
               torch
              
                ::
              
              
                from_blob
              
              
                (
              
              s
              
                .
              
              pbuf
              
                ,
              
               s
              
                .
              
              shape
              
                )
              
              
                .
              
              
                unsqueeze
              
              
                (
              
              
                0
              
              
                )
              
              
                .
              
              
                to
              
              
                (
              
              mDevice
              
                )
              
              
                ;
              
              
        torch
              
                ::
              
              Tensor action
              
                ;
              
              
        torch
              
                ::
              
              Tensor logProb
              
                ;
              
              
        std
              
                ::
              
              
                tie
              
              
                (
              
              action
              
                ,
              
               logProb
              
                )
              
              
                =
              
               mPolicy
              
                -
              
              
                >
              
              
                act
              
              
                (
              
              state
              
                )
              
              
                ;
              
              
        mLogProbs
              
                .
              
              
                push_back
              
              
                (
              
              logProb
              
                )
              
              
                ;
              
              
                return
              
               action
              
                ;
              
              
                }
              
              
                void
              
              
                update
              
              
                (
              
              
                float
              
               r
              
                ,
              
              
                __attribute__
              
              
                (
              
              
                (
              
              unused
              
                )
              
              
                )
              
              
                bool
              
               done
              
                )
              
              
                {
              
              
        mRewards
              
                [
              
              mSize
              
                ++
              
              
                ]
              
              
                =
              
               r
              
                ;
              
              
                if
              
              
                (
              
              mSize 
              
                >=
              
               mCapacity
              
                )
              
              
                {
              
              
            spdlog
              
                ::
              
              
                info
              
              
                (
              
              
                "buffer has been full, call train()"
              
              
                )
              
              
                ;
              
              
                train
              
              
                (
              
              
                )
              
              
                ;
              
              
                }
              
              
                }
              
              
                virtual
              
              
                void
              
              
                onEpisodeFinished
              
              
                (
              
              
                )
              
               override 
              
                {
              
              
                train
              
              
                (
              
              
                )
              
              
                ;
              
              
                }
              
              
                private
              
              
                :
              
              
                void
              
              
                train
              
              
                (
              
              
                )
              
              
                {
              
              
        spdlog
              
                ::
              
              
                trace
              
              
                (
              
              
                "{}: buffer size = {}"
              
              
                ,
              
              
                __func__
              
              
                ,
              
               mSize
              
                )
              
              
                ;
              
              
                for
              
              
                (
              
              
                auto
              
               i 
              
                =
              
               mSize 
              
                -
              
              
                1
              
              
                ;
              
               i 
              
                >=
              
              
                0
              
              
                ;
              
              
                --
              
              i
              
                )
              
              
                {
              
              
                if
              
              
                (
              
              i 
              
                ==
              
              
                (
              
              mSize 
              
                -
              
              
                1
              
              
                )
              
              
                )
              
              
                {
              
              
                mReturns
              
                [
              
              i
              
                ]
              
              
                =
              
               mRewards
              
                [
              
              i
              
                ]
              
              
                ;
              
              
                }
              
              
                else
              
              
                {
              
              
                mReturns
              
                [
              
              i
              
                ]
              
              
                =
              
               mReturns
              
                [
              
              i 
              
                +
              
              
                1
              
              
                ]
              
              
                *
              
               mGamma 
              
                +
              
               mRewards
              
                [
              
              i
              
                ]
              
              
                ;
              
              
                }
              
              
                }
              
              
                auto
              
               returns 
              
                =
              
               mReturns
              
                .
              
              
                slice
              
              
                (
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
               mSize
              
                )
              
              
                ;
              
              
        returns 
              
                =
              
              
                (
              
              returns 
              
                -
              
               returns
              
                .
              
              
                mean
              
              
        returns = (returns - returns.mean()) / (returns.std() + kEps);
        auto logprobs = torch::cat(mLogProbs);

        mOptimizer->zero_grad();
        auto policy_loss = -(logprobs * returns).sum();
        policy_loss.backward();
        mOptimizer->step();

        mLogProbs.clear();
        mSize = 0;
        ++mCount;
        spdlog::debug("{} : episode {}: loss = {}, accumulated reward = {}",
                      __func__, mCount, policy_loss.item<float>(),
                      mRewards.sum().item<float>());
    }

    std::shared_ptr<Policy> mPolicy;

    torch::Tensor mRewards;
    std::vector<torch::Tensor> mLogProbs;
    torch::Tensor mReturns;
    int32_t mSize{0};
    int32_t mCapacity{kExpBufferCap};

    std::shared_ptr<torch::optim::Adam> mOptimizer;
    uint32_t mCount{0};
    float mGamma{0.99};
    float mInitLR{1e-2};
};
              
            
          

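The returns consumed by this update are the usual discounted reward-to-go sums; the accumulation itself happens elsewhere in the class. As a minimal sketch of that computation (a hypothetical helper assuming per-step rewards collected in a std::vector; this is not the article's actual code):

#include <torch/torch.h>

// Hypothetical helper: discounted reward-to-go, accumulated backwards over an
// episode, producing the `returns` tensor that is normalized and consumed above.
torch::Tensor computeReturns(const std::vector<float>& rewards, float gamma) {
    std::vector<float> returns(rewards.size());
    float running = 0.f;
    for (int64_t i = static_cast<int64_t>(rewards.size()) - 1; i >= 0; --i) {
        running = rewards[i] + gamma * running;  // G_t = r_t + gamma * G_{t+1}
        returns[i] = running;
    }
    return torch::tensor(returns);  // copy into a 1-D float tensor
}

Normalizing the returns by their mean and standard deviation is a common variance-reduction trick for REINFORCE; kEps guards against division by zero when all rewards in an episode are equal.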
The environment in the sample is CartPole. Since the scenario is simple, the policy function is implemented as a plain MLP; more complex scenarios can swap in a more sophisticated network architecture (see the CNN sketch after the listing below).

            
              
...
class Net : public nn::Module {
public:
    virtual std::tuple<Tensor, Tensor> forward(Tensor x) = 0;
    virtual ~Net() = default;
};

class MLP : public Net {
public:
    MLP(int64_t inputSize, int64_t actionNum) {
        mFC1 = register_module("fc1", nn::Linear(inputSize, mHiddenSize));
        mAction = register_module("action", nn::Linear(mHiddenSize, actionNum));
        mValue = register_module("value", nn::Linear(mHiddenSize, actionNum));
    }

    virtual std::tuple<Tensor, Tensor> forward(Tensor x) override {
        x = mFC1->forward(x);
        x = dropout(x, 0.6, is_training());
        x = relu(x);
        return std::make_tuple(mAction->forward(x), mValue->forward(x));
    }

private:
    nn::Linear mFC1{nullptr};
    nn::Linear mAction{nullptr};
    nn::Linear mValue{nullptr};
    int64_t mHiddenSize{128};
};

class Policy : public torch::nn::Module {
public:
    Policy(const SpaceDesc& ss, const SpaceDesc& as, torch::Device mDevice)
        : mActionSpaceType(as.stype),
          mActionNum(as.shape[0]),
          mGen(kSeed),
          mUniformDist(0, 1.0) {
        // Pick the network by observation shape: vector observations get the
        // MLP, image-like (multi-dimensional) observations get a CNN.
        if (ss.shape.size() == 1) {
            mNet = std::make_shared<MLP>(ss.shape[0], as.shape[0]);
        } else {
            mNet = std::make_shared<CNN>(ss.shape, as.shape[0]);
        }
        mNet->to(mDevice);
        register_module("base", mNet);

        // Uniform categorical over actions, used for epsilon-greedy exploration.
        torch::Tensor logits = torch::ones({1, as.shape[0]}, torch::TensorOptions(mDevice));
        mUniformCategorical = std::make_shared<Categorical>(nullptr, &logits);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = std::get<0>(mNet->forward(x));
        return torch::softmax(x, 1);
    }

    std::tuple<torch::Tensor, torch::Tensor> act(torch::Tensor state) {
        auto output = forward(state);
        std::shared_ptr<Distribution> dist;
        if (!mActionSpaceType.compare("Discrete")) {
            dist = std::make_shared<Categorical>(&output);
        } else {
            throw std::logic_error("Not implemented : action space");
        }
        // Exploration threshold decays exponentially with the step count.
        float rnd = mUniformDist(mGen);
        float threshold = kEpsEnd + (kEpsStart - kEpsEnd) * exp(-1. * mStep / kEpsDecay);
        ++mStep;
        torch::Tensor action;
        if (rnd > threshold) {
            torch::NoGradGuard no_grad;
            action = dist->sample();
        } else {
            torch::NoGradGuard no_grad;
            action = mUniformCategorical->sample({1}).squeeze(-1);
        }
        auto log_probs = dist->log_prob(action);
        return std::make_tuple(action, log_probs);
    }

private:
    std::string mActionSpaceType;
    int32_t mActionNum;
    int64_t mHiddenSize{128};
    std::shared_ptr<Net> mNet;
    uint64_t mStep{0};

    std::mt19937 mGen;
    std::uniform_real_distribution<float> mUniformDist;
    std::shared_ptr<Categorical> mUniformCategorical;
};
              
            
          
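When the observation shape has more than one dimension, the Policy constructor above expects a CNN subclass of Net, whose definition the article does not show. A minimal sketch of what it might look like follows; the two-convolution layout, channel counts, and kernel sizes are illustrative assumptions, and it reuses the same unqualified nn/Tensor names as the classes above:

// Hypothetical CNN backbone satisfying the Net interface used by Policy.
class CNN : public Net {
public:
    // shape is assumed to be {channels, height, width}
    CNN(const std::vector<int64_t>& shape, int64_t actionNum) {
        mConv1 = register_module("conv1", nn::Conv2d(nn::Conv2dOptions(shape[0], 16, 5).stride(2)));
        mConv2 = register_module("conv2", nn::Conv2d(nn::Conv2dOptions(16, 32, 5).stride(2)));
        // Spatial extent after a 5x5 stride-2 convolution, applied twice.
        auto convOut = [](int64_t size) { return (size - 5) / 2 + 1; };
        int64_t flat = 32 * convOut(convOut(shape[1])) * convOut(convOut(shape[2]));
        mAction = register_module("action", nn::Linear(flat, actionNum));
        mValue = register_module("value", nn::Linear(flat, actionNum));
    }

    virtual std::tuple<Tensor, Tensor> forward(Tensor x) override {
        x = torch::relu(mConv1->forward(x));
        x = torch::relu(mConv2->forward(x));
        x = x.view({x.size(0), -1});  // flatten for the linear heads
        return std::make_tuple(mAction->forward(x), mValue->forward(x));
    }

private:
    nn::Conv2d mConv1{nullptr};
    nn::Conv2d mConv2{nullptr};
    nn::Linear mAction{nullptr};
    nn::Linear mValue{nullptr};
};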

The Categorical class here implements the categorical-distribution computations; it can be rewritten in C++ by following the Python version in PyTorch (torch.distributions.Categorical).
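As a rough illustration of such a port, a categorical distribution needs little more than sampling and log-probabilities. The sketch below is an assumption about the interface, not the article's implementation (the real class also accepts logits, as the Policy constructor above shows):

#include <torch/torch.h>

// Minimal sketch of a C++ Categorical mirroring torch.distributions.Categorical.
// Assumes probs is already normalized (the Policy above applies softmax).
class Categorical {
public:
    explicit Categorical(const torch::Tensor* probs) : mProbs(*probs) {}

    // Draw one action index per batch row, weighted by mProbs.
    torch::Tensor sample() {
        return torch::multinomial(mProbs, 1, /*replacement=*/true).squeeze(-1);
    }

    // log pi(a) for the given action indices.
    torch::Tensor log_prob(const torch::Tensor& value) {
        auto logp = torch::log(mProbs);
        return logp.gather(-1, value.unsqueeze(-1)).squeeze(-1);
    }

private:
    torch::Tensor mProbs;
};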

Finally, compile the C++ implementation above into a shared library (.so). Adapt and add something like the following to CMakeLists.txt:

            
              
...
set(CMAKE_CXX_STANDARD 11)

find_package(Torch REQUIRED)

set(NRL_INCLUDE_DIRS
    src
    ${TORCH_INCLUDE_DIRS})

file(GLOB NRL_SOURCES1 "src/*.cpp")
list(APPEND NRL_SOURCES ${NRL_SOURCES1})
message(STATUS "sources: ${NRL_SOURCES}")

add_subdirectory(third_party/pybind11)
add_subdirectory(third_party/spdlog)

pybind11_add_module(nativerl ${NRL_SOURCES})
target_include_directories(nativerl PRIVATE ${NRL_INCLUDE_DIRS})
target_link_libraries(nativerl PRIVATE spdlog::spdlog ${TORCH_LIBRARIES})
...

            
          
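pybind11_add_module builds these sources into a Python extension named nativerl, so one of the sources must define the module entry point. The article does not show it, but it would look roughly like this (the rl_wrapper.h header and the RLWrapper constructor/method signatures are assumptions, modeled on the reset/act/update interface described earlier):

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "rl_wrapper.h"  // hypothetical header declaring the C++ RLWrapper

namespace py = pybind11;

// Entry point of the nativerl extension module.
PYBIND11_MODULE(nativerl, m) {
    py::class_<RLWrapper>(m, "RLWrapper")
        .def(py::init<const py::dict&, const py::dict&>())  // state/action space descriptions
        .def("reset", &RLWrapper::reset)
        .def("act", &RLWrapper::act)
        .def("update", &RLWrapper::update);
}

With this in place, the Python side can simply "import nativerl" once the .so is on PYTHONPATH.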

Assuming the compiled .so sits under the build directory and the Python entry script is example/simple.py, training can be started with:

            
PYTHONPATH=./build python -m example.simple

            
          

If everything works, you should see training logs and results like the following, essentially consistent with the Python version.

            
              
[2019-06-22 13:42:22.533] [info] state space type:Box shape size:1
[2019-06-22 13:42:22.534] [info] action space type:Discrete, shape size:1
[2019-06-22 13:42:22.534] [info] Training on GPU (CUDA)
              
              
Episode 0	 Last reward: 29.00	 step: 29	 Average reward: 10.95
Episode 10	 Last reward: 17.00	 step: 17	 Average reward: 14.73
Episode 20	 Last reward: 12.00	 step: 12	 Average reward: 17.40
Episode 30	 Last reward: 15.00	 step: 15	 Average reward: 24.47
Episode 40	 Last reward: 18.00	 step: 18	 Average reward: 26.22
Episode 50	 Last reward: 18.00	 step: 18	 Average reward: 23.69
Episode 60	 Last reward: 72.00	 step: 72	 Average reward: 30.21
Episode 70	 Last reward: 19.00	 step: 19	 Average reward: 28.83
Episode 80	 Last reward: 29.00	 step: 29	 Average reward: 32.13
Episode 90	 Last reward: 15.00	 step: 15	 Average reward: 29.64
Episode 100	 Last reward: 30.00	 step: 30	 Average reward: 27.88
Episode 110	 Last reward: 12.00	 step: 12	 Average reward: 26.14
Episode 120	 Last reward: 28.00	 step: 28	 Average reward: 26.32
Episode 130	 Last reward: 11.00	 step: 11	 Average reward: 31.20
Episode 140	 Last reward: 112.00	 step: 112	 Average reward: 35.26
Episode 150	 Last reward: 40.00	 step: 40	 Average reward: 37.14
Episode 160	 Last reward: 40.00	 step: 40	 Average reward: 36.84
Episode 170	 Last reward: 15.00	 step: 15	 Average reward: 41.91
Episode 180	 Last reward: 63.00	 step: 63	 Average reward: 49.78
Episode 190	 Last reward: 21.00	 step: 21	 Average reward: 44.70
Episode 200	 Last reward: 46.00	 step: 46	 Average reward: 41.83
Episode 210	 Last reward: 80.00	 step: 80	 Average reward: 51.55
Episode 220	 Last reward: 151.00	 step: 151	 Average reward: 57.82
Episode 230	 Last reward: 176.00	 step: 176	 Average reward: 62.80
Episode 240	 Last reward: 19.00	 step: 19	 Average reward: 63.17
Episode 250	 Last reward: 134.00	 step: 134	 Average reward: 74.02
Episode 260	 Last reward: 46.00	 step: 46	 Average reward: 71.35
Episode 270	 Last reward: 118.00	 step: 118	 Average reward: 85.88
Episode 280	 Last reward: 487.00	 step: 487	 Average reward: 112.74
Episode 290	 Last reward: 95.00	 step: 95	 Average reward: 139.41
Episode 300	 Last reward: 54.00	 step: 54	 Average reward: 149.20
Episode 310	 Last reward: 417.00	 step: 417	 Average reward: 138.42
Episode 320	 Last reward: 500.00	 step: 500	 Average reward: 179.29
Episode 330	 Last reward: 71.00	 step: 71	 Average reward: 195.88
Episode 340	 Last reward: 309.00	 step: 309	 Average reward: 216.82
Episode 350	 Last reward: 268.00	 step: 268	 Average reward: 214.21
Episode 360	 Last reward: 243.00	 step: 243	 Average reward: 210.89
Episode 370	 Last reward: 266.00	 step: 266	 Average reward: 200.03
Episode 380	 Last reward: 379.00	 step: 379	 Average reward: 220.06
Episode 390	 Last reward: 500.00	 step: 500	 Average reward: 316.20
Episode 400	 Last reward: 500.00	 step: 500	 Average reward: 369.46
Episode 410	 Last reward: 500.00	 step: 500	 Average reward: 421.84
Episode 420	 Last reward: 500.00	 step: 500	 Average reward: 453.20
Episode 430	 Last reward: 500.00	 step: 500	 Average reward: 471.98
Solved. Running reward: 475.9764491024681, Last reward: 500


            
          
