Comments (7)

BoYanSTKO commented on May 13, 2024

Here is my implementation:
https://github.com/BoYanSTKO/Practical_RL-coursera/blob/master/week4_approx/dqn_atari.ipynb

from practical_rl.

BoYanSTKO commented on May 13, 2024

This is my implementation for the agent:

import keras
import numpy as np
import tensorflow as tf
from keras.layers import Conv2D, Dense, Flatten, InputLayer


class DQNAgent:
    def __init__(self, name, state_shape, n_actions, epsilon=0, reuse=False):
        """A simple DQN agent"""
        with tf.variable_scope(name, reuse=reuse):
            
            # < Define your network body here. Please make sure you don't use any layers created elsewhere >
            self.network = keras.models.Sequential()
    
            # Keras ignores the first dimension in the input_shape, which is the batch size. 
            # So just use state_shape for the input shape
            self.network.add(Conv2D(16, (3, 3), strides=2, activation='relu', input_shape=state_shape))
            self.network.add(Conv2D(32, (3, 3), strides=2, activation='relu'))
            self.network.add(Conv2D(64, (3, 3), strides=2, activation='relu'))
            self.network.add(Flatten())
            self.network.add(Dense(256, activation='relu'))
            self.network.add(Dense(n_actions, activation='linear'))
            
            # Store n_actions so get_symbolic_qvalues can validate its output shape
            # without relying on a global variable.
            self.n_actions = n_actions

            # prepare a graph for agent step
            self.state_t = tf.placeholder('float32', [None,] + list(state_shape))
            self.qvalues_t = self.get_symbolic_qvalues(self.state_t)
            
        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
        self.epsilon = epsilon

    def get_symbolic_qvalues(self, state_t):
        """takes agent's observation, returns qvalues. Both are tf Tensors"""
        
        qvalues = self.network(state_t)
        
        assert tf.is_numeric_tensor(qvalues) and qvalues.shape.ndims == 2, \
            "please return 2d tf tensor of qvalues [you got %s]" % repr(qvalues)
        assert int(qvalues.shape[1]) == self.n_actions
        
        return qvalues
    
    def get_qvalues(self, state_t):
        """Same as symbolic step except it operates on numpy arrays"""
        sess = tf.get_default_session()
        return sess.run(self.qvalues_t, {self.state_t: state_t})
    
    def sample_actions(self, qvalues):
        """pick actions given qvalues. Uses epsilon-greedy exploration strategy. """
        epsilon = self.epsilon
        batch_size, n_actions = qvalues.shape
        random_actions = np.random.choice(n_actions, size=batch_size)
        best_actions = qvalues.argmax(axis=-1)
        should_explore = np.random.choice([0, 1], batch_size, p = [1-epsilon, epsilon])
        return np.where(should_explore, random_actions, best_actions)
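For anyone who wants to check the epsilon-greedy step in isolation, here is a minimal NumPy-only sketch. It is not the notebook's code: the standalone function signature, the `rng` argument, and the toy Q-values are assumptions made for illustration.

```python
import numpy as np


def sample_actions(qvalues, epsilon, rng):
    """Epsilon-greedy action selection for a batch of Q-value rows."""
    batch_size, n_actions = qvalues.shape
    random_actions = rng.integers(n_actions, size=batch_size)
    best_actions = qvalues.argmax(axis=-1)
    # With probability epsilon, take the random action instead of the greedy one.
    should_explore = rng.random(batch_size) < epsilon
    return np.where(should_explore, random_actions, best_actions)


rng = np.random.default_rng(0)
qvalues = np.array([[0.1, 0.9, 0.0],
                    [0.5, 0.2, 0.3]])

# epsilon=0 never explores, so the result is the row-wise argmax.
greedy = sample_actions(qvalues, epsilon=0.0, rng=rng)

# With epsilon=0.5 and 3 actions, the greedy action (index 1) should be
# picked with probability 0.5 + 0.5 / 3, i.e. about two thirds of the time.
many_rows = np.tile(qvalues[:1], (10000, 1))
explore = sample_actions(many_rows, epsilon=0.5, rng=rng)
greedy_fraction = (explore == 1).mean()
```

Note that a random draw can still land on the greedy action, which is why the greedy fraction is higher than 1 - epsilon.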

Implementation for the Q-learning part:

# compute q-values for NEXT states with target network
next_qvalues_target = target_network.get_symbolic_qvalues(next_obs_ph)

# compute state values by taking max over next_qvalues_target for all actions
# next_actions = tf.argmax(next_qvalues_target, axis=-1)
# next_state_values_target = tf.reduce_sum(tf.one_hot(next_actions, n_actions) * next_qvalues_target, axis=1)
next_state_values_target = tf.reduce_max(next_qvalues_target, axis=-1)

# compute Q_reference(s,a) as per formula above.
reference_qvalues = rewards_ph + gamma*next_state_values_target*is_not_done

# Define loss function for sgd.
td_loss = (current_action_qvalues - reference_qvalues) ** 2
td_loss = tf.reduce_mean(td_loss)

train_step = tf.train.AdamOptimizer(1e-3).minimize(td_loss, var_list=agent.weights)
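To sanity-check the target and loss computation without building a TensorFlow graph, the same arithmetic can be reproduced in NumPy on a tiny batch. This is only a sketch: the value of `gamma` and every array below are made-up numbers, not values from the notebook.

```python
import numpy as np

gamma = 0.99  # discount factor (assumed value for this example)

# Hypothetical batch of three transitions.
rewards = np.array([1.0, 0.0, -1.0])
is_done = np.array([0.0, 0.0, 1.0])  # the last transition ends the episode
is_not_done = 1.0 - is_done
next_qvalues_target = np.array([[0.5, 2.0],
                                [1.0, 0.0],
                                [3.0, 4.0]])  # target network output for next states
current_action_qvalues = np.array([2.5, 1.0, 0.0])  # Q(s, a) for the actions taken

# V(s') = max over actions of Q_target(s', a')
next_state_values_target = next_qvalues_target.max(axis=-1)

# Q_reference(s, a) = r + gamma * V(s'); the bootstrap term is zeroed at terminal states.
reference_qvalues = rewards + gamma * next_state_values_target * is_not_done

# Mean squared TD error over the batch.
td_loss = np.mean((current_action_qvalues - reference_qvalues) ** 2)

print(reference_qvalues)  # approximately [2.98, 0.99, -1.0]
```

The third transition shows why `is_not_done` matters: its next-state value of 4.0 is ignored, so the reference is just the reward.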

etendue commented on May 13, 2024

Hi, I am facing the same issue you had. Judging by your repository, you reached a score above 10. What was the fix?

BoYanSTKO commented on May 13, 2024

@etendue Sorry for the late reply. I didn't change much. I changed the image preprocessing part to follow the implementation in their later IPython notebook. Then I ran more episodes and it worked. I feel like there are a lot of heuristics and instability in the model.
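For readers unsure what "image preprocessing" covers here, the following is a rough sketch of a typical Atari frame pipeline (crop, grayscale, downsample, rescale). It is not the code from the later notebook: the function name, the crop rows, and the stride are all assumptions chosen for illustration.

```python
import numpy as np


def preprocess_frame(frame, crop_top=34, crop_bottom=194, stride=2):
    """Crop, grayscale, downsample and rescale one RGB frame to float32 in [0, 1]."""
    cropped = frame[crop_top:crop_bottom]   # drop rows outside the play area (assumed rows)
    gray = cropped.mean(axis=-1)            # average the RGB channels
    small = gray[::stride, ::stride]        # downsample by plain striding
    scaled = (small / 255.0).astype(np.float32)
    return scaled[..., None]                # keep a channel axis for Conv2D


# A random stand-in for a 210x160 RGB Atari frame.
frame = np.random.default_rng(0).integers(0, 256, size=(210, 160, 3), dtype=np.uint8)
obs = preprocess_frame(frame)
print(obs.shape)  # (80, 80, 1)
```

The resulting shape is what would be passed as `state_shape` to the agent above.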

justheuristic commented on May 13, 2024

"I feel like there are a lot of heuristics and instability in the model."

Unfortunately, you are correct. We had to choose between training time and stability. Judging by the amount of feedback on this assignment (Thanks!), we'll see if we can improve it later this month.

justheuristic commented on May 13, 2024

Chances are, we'll finally have a user-friendly and easy-to-run notebook soon enough thanks to @zshrav.

Btw @zshrav once it's ready, would you kindly post a link to the updated week4 here and close the issue?

justheuristic commented on May 13, 2024

The overhaul is complete (by zhrav@), feel free to open new issues if you have any... y'know... issues.
