tud-amr / social_vrnn
License: GNU General Public License v3.0
I would like to know what value of the truncated_backprop_length parameter was used to produce the results displayed in the publication. The train.sh script hints at the value being either 1 or 12. Was there a fixed value chosen for all experiments across the five UCY/ETH datasets?
On this topic, I would just like to know for sure: the observation time T_O mentioned in the article is 8, yet the example trainings shown in the train.sh script set prev_horizon to either 0 or 7. Was the value of prev_horizon chosen as 8 in your experiments?
Finally, could you tell me the reason behind implementing both truncated backpropagation through time (truncated_backprop_length) and a fixed observation time (prev_horizon)? I understand that they operate in different manners in the code. But as currently implemented, the input feature vectors for the past trajectories are essentially shifted versions of the velocity signal of the query agent, as such:
batch_vel.shape = [batch_size, truncated_backprop_length, input_state_dim * (prev_horizon + 1)]
Let:
truncated_backprop_length = 3
prev_horizon = 4
Then:
batch_vel[0] =
[[vx_-2, vy_-2, vx_-3, vy_-3, vx_-4, vy_-4, vx_-5, vy_-5, vx_-6, vy_-6],
[vx_-1, vy_-1, vx_-2, vy_-2, vx_-3, vy_-3, vx_-4, vy_-4, vx_-5, vy_-5],
[vx_0, vy_0, vx_-1, vy_-1, vx_-2, vy_-2, vx_-3, vy_-3, vx_-4, vy_-4]]
This results in an input feature vector containing duplicate values of the velocity signal across truncation steps within the batch (i.e., along the second dimension of the batch).
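The duplication is easy to reproduce with a small NumPy sketch using the same hypothetical values as above (truncated_backprop_length = 3, prev_horizon = 4); the velocity values here are made up for illustration:

```python
import numpy as np

# Hypothetical velocity history of one agent: (vx_t, vy_t) for t = -6 .. 0.
vel = np.array([[t, 10 + t] for t in range(-6, 1)], dtype=float)  # shape (7, 2)

truncated_backprop_length = 3
prev_horizon = 4

# Build the input as described: for each truncation step k, concatenate
# the prev_horizon + 1 most recent velocities, newest first.
rows = []
for k in range(truncated_backprop_length):
    t0 = k + prev_horizon                     # index of the "current" timestep
    window = vel[t0::-1][:prev_horizon + 1]   # newest-first window
    rows.append(window.ravel())
batch_vel_0 = np.stack(rows)  # shape (3, 2 * (prev_horizon + 1)) = (3, 10)

# Consecutive rows share 2 * prev_horizon = 8 of their 10 values,
# merely shifted by one timestep.
print(batch_vel_0)
```

The overlap between consecutive rows is what makes most of the stored feature values redundant.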
In the Support file, we can see the rotate_batch_to_local_frame() function, which is invoked when passing the --rotated_grid true flag while training. The function calculates the heading angle of the agent and then rotates the agent's coordinates by this heading angle.
social_vrnn/src/data_utils/Support.py
Line 107 in 707cada
Traceback (most recent call last):
File "train.py", line <line where getBatch gets called>, in <module>
batch_x, batch_vel, batch_pos, batch_goal, batch_grid, batch_ped_grid, batch_y, batch_pos_target, other_agents_pos, new_epoch = data_prep.getBatch()
File "/path/to/DataHandlerLSTM.py", line <getBatch line>, in getBatch
_, self.batch_y = sup.rotate_batch_to_local_frame(self.batch_y, self.batch_x)
File "/path/to/Support.py", line <rotate_batch_to_local_frame line>, in rotate_batch_to_local_frame
bx[batch_idx, tbp_step, 2:] = np.dot(rot_mat, bx[batch_idx, tbp_step, 2:])
File "<__array_function__ internals>", line 6, in dot
ValueError: shapes (2,2) and (34,) not aligned: 2 (dim 1) != 34 (dim 0)
This is due to the fact that the velocity/position array bx stores the coordinates flattened along a single dimension, while applying the 2x2 rotation matrix requires the positions and velocities to be organised as (x, y) pairs. Looking at the implementation of this line, the goal is to rotate every coordinate pair contained within bx except the first two values, which correspond to the position at the time of origin t_0. We can fix this by replacing the highlighted line with a loop over the coordinate pairs:
for pred_step in range(1, bx.shape[2] // 2):
bx[batch_idx, tbp_step, 2 * pred_step:2 * (pred_step + 1)] = np.dot(
rot_mat, bx[batch_idx, tbp_step, 2 * pred_step:2 * (pred_step + 1)]
)
This should behave the way the original implementation intended. However, I am still not 100% certain this implementation is correct. Rotating by the heading angle in this way means that, although the velocity components are properly transformed, the positions are not: they are still expressed in the global world coordinate system, so they get rotated about the world origin, whereas they need to be rotated about the agent's point of origin. Is that correct?
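The distinction can be illustrated with a small sketch (hypothetical values; rot_mat stands in for the heading rotation computed inside rotate_batch_to_local_frame()): velocities are direction vectors and rotate directly, while world-frame positions must first be translated to the agent's origin:

```python
import numpy as np

theta = np.pi / 2  # hypothetical heading angle of the agent
rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])

origin = np.array([5.0, 5.0])   # agent position at t_0, in the world frame
pos    = np.array([6.0, 5.0])   # some other position, in the world frame
vel    = np.array([1.0, 0.0])   # a velocity (direction vector)

# Velocities transform with the rotation alone:
vel_local = rot_mat @ vel

# Positions must be rotated about the agent's origin, not the world origin:
pos_wrong = rot_mat @ pos             # rotates about (0, 0) -- the concern above
pos_local = rot_mat @ (pos - origin)  # rotate about the agent's origin

print(vel_local, pos_wrong, pos_local)
```

Here the point one metre ahead of the agent maps to (0, 1) in the local frame, while rotating it about the world origin lands somewhere else entirely.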
I will continue investigating this bit of code; if I make any advancements or discoveries here, I will post them.
(Also, I will try to submit a pull request with the fix if I can. The problem is that I am working on a number of elements of the repository, and I wouldn't want my additional implementations to alter the state of the original repo. If there is a way for me to submit a pull request containing only specific lines / bug fixes, without also submitting my unrelated work, that would be nice to know.)
Reproducing the error: execute the ./dowload_data.sh script.
Output: an ERROR 503 (Service Unavailable) is received upon sending the HTTPS request for the trained_models.zip file.
I suppose this is due to a broken link for this specific file; downloading the data.zip file works just fine.
social_vrnn/src/models/SocialVRNN.py
Line 234 in 707cada
Here we have the definition of the KL annealing coefficient, which does not correspond to the definition mentioned in the publication.
The publication mentions:
"We used a KL annealing coefficient λ = tanh( (step - 10^4) /10^3 ), with step as the training step."
Which of the two is correct?
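For reference, the schedule quoted from the publication is easy to write down and inspect (a sketch only; whether negative values are clipped before use is not stated in the quote):

```python
import math

def kl_coefficient(step):
    """KL annealing coefficient as quoted from the publication:
    lambda = tanh((step - 10^4) / 10^3)."""
    return math.tanh((step - 1e4) / 1e3)

# The coefficient is close to -1 early in training, crosses 0 at
# step 10^4, and saturates towards 1 for step >> 10^4:
for step in (0, 10_000, 12_000, 20_000):
    print(step, round(kl_coefficient(step), 4))
```

Comparing this curve against what line 234 actually computes should make it clear which of the two definitions the released code follows.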
social_vrnn/src/data_utils/DataHandlerLSTM.py
Line 871 in 707cada
The definition of the current velocity array is incorrect. We can compare it with the way the current_pos array is built (which is the correct way) right above it:
for prev_step in range(self.prev_horizon,-1,-1):
current_pos = np.array([trajectory.pose_vec[start_idx + tbp_step - prev_step, 0], trajectory.pose_vec[
start_idx + tbp_step - prev_step, 1]])
current_vel = np.array([trajectory.vel_vec[start_idx + tbp_step, 0], trajectory.vel_vec[
start_idx + tbp_step - prev_step - prev_step, 1]])
current_vel should instead be:
current_vel = np.array([trajectory.vel_vec[start_idx + tbp_step - prev_step, 0], trajectory.vel_vec[
start_idx + tbp_step - prev_step, 1]])
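To sanity-check the corrected indexing, here is a minimal sketch with a hypothetical trajectory (pose_vec and vel_vec are stand-ins for the real trajectory arrays), showing that the position and velocity windows then stay aligned timestep-for-timestep:

```python
import numpy as np

# Hypothetical trajectory: 20 timesteps of (x, y) poses; velocities are
# made proportional to poses purely so alignment is easy to verify.
pose_vec = np.arange(40, dtype=float).reshape(20, 2)
vel_vec = pose_vec * 0.1

start_idx, tbp_step, prev_horizon = 8, 3, 4

pos_window, vel_window = [], []
for prev_step in range(prev_horizon, -1, -1):
    t = start_idx + tbp_step - prev_step
    # With the corrected indexing, pos and vel are read at the SAME timestep t:
    pos_window.append(pose_vec[t])
    vel_window.append(vel_vec[t])

# Each velocity in the window matches the pose it was derived from.
assert np.allclose(np.array(vel_window), np.array(pos_window) * 0.1)
```

With the original code, the x component would always come from timestep start_idx + tbp_step and the y component from start_idx + tbp_step - 2 * prev_step, so this check would fail for every prev_step > 0.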
The article mentions that the SocialVRNN architecture is trained on four of the five UCY/ETH datasets, with its performance then evaluated on the remaining one. However, the DataHandler class implemented in DataHandlerLSTM.py can only load one dataset at a time. I would like to reproduce the experiments conducted in the article; how should I perform the training on four datasets with the DataHandler provided?
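Absent multi-dataset support in DataHandlerLSTM, one possible workaround (a sketch under my own assumptions, not the authors' method; make_handler and the scene list below are hypothetical) is a thin wrapper that holds one handler per training scene and cycles between them when drawing batches:

```python
import itertools

class MultiSceneHandler:
    """Round-robin over several single-scene data handlers."""

    def __init__(self, handlers):
        self.handlers = handlers
        self._cycle = itertools.cycle(handlers)

    def getBatch(self):
        # Each call draws the next batch from the next scene in turn,
        # so gradient updates interleave all training scenes.
        return next(self._cycle).getBatch()

# Hypothetical usage for a leave-one-out split (test scene: "eth"):
# train_scenes = ["hotel", "univ", "zara01", "zara02"]
# handler = MultiSceneHandler([make_handler(s) for s in train_scenes])
```

Whether plain interleaving matches the training protocol used for the paper is exactly what I would like the authors to confirm.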
I am reporting a problem with the DataHandler module import:
/social_vrnn/src$ ./train.sh
Using Python 3
Traceback (most recent call last):
File "train.py", line 20, in
from src.data_utils import DataHandler as dh
ImportError: cannot import name 'DataHandler' from 'src.data_utils' (unknown location)
(The same traceback repeats for each of the subsequent training runs launched by train.sh.)
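The error suggests train.py is being executed from inside src/ while importing through the src. package path. One common workaround (an assumption on my part, not verified against this repo's layout) is to run from the repository root with the root on PYTHONPATH:

```shell
# Assumption: social_vrnn/ is the repository root and src/ is a package
# (i.e. contains an __init__.py). Run from the root so `src.data_utils`
# can be resolved as a package:
cd /path/to/social_vrnn
export PYTHONPATH="$PWD:$PYTHONPATH"
python src/train.py
```

If src/data_utils/ lacks an __init__.py, adding an empty one is also worth checking, since "(unknown location)" in the ImportError usually means the package itself could not be resolved.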
Thank you
While going through the code, I tried to download the pre-trained models through "dowload_data.sh", but the link for downloading the pre-trained models has expired.
Could you please share the pre-trained models of your amazing work?