uber-research / lanegcn

481 stars · 9 watchers · 132 forks · 20.31 MB

[ECCV2020 Oral] Learning Lane Graph Representations for Motion Forecasting

Home Page: https://arxiv.org/abs/2007.13732

License: Other

Languages: Python 98.34%, Shell 1.66%
Topics: self-driving, motion-estimation, graph-neural-networks, artificial-intelligence, motion-forecasting

lanegcn's People

Contributors: chenyuntc, jonathanbaker7, wqi


lanegcn's Issues

dilated_nbrs bug

Should mat = mat * mat be changed to mat = mat * csr in the function dilated_nbrs?
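
For reference, here is a small self-contained sketch (not the repository's dilated_nbrs itself; the chain graph and the scipy usage are illustrative) contrasting the two update rules: with mat = mat * mat the reachable hop distance doubles at each scale (2, 4, 8, ...), while mat = mat * csr grows it by one hop per scale (2, 3, 4, ...).

import numpy as np
from scipy import sparse

num_nodes = 6
u = np.arange(1, num_nodes)          # edges i -> i-1 along a simple chain
v = np.arange(0, num_nodes - 1)
data = np.ones(len(u))
csr = sparse.csr_matrix((data, (u, v)), shape=(num_nodes, num_nodes))

mat_sq, mat_lin = csr, csr
for scale in range(1, 4):
    mat_sq = mat_sq * mat_sq         # reachable hop distance doubles: 2, 4, 8, ...
    mat_lin = mat_lin * csr          # reachable hop distance grows by one: 2, 3, 4, ...
    print("scale", scale)
    print(mat_sq.toarray().astype(int))
    print(mat_lin.toarray().astype(int))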

Learning Rate Drop

Hi! Thank you for your cool project!

I have a question which might be stupid. I notice that the repository uses neither learning-rate warm-up nor per-iteration decay of the learning rate. Since these techniques are common in other deep learning applications, I am wondering whether the current schedule works better for LaneGCN, or whether the architecture is already good enough without such tricks.

Thanks!
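
For context, a minimal sketch of the kind of warm-up plus step-decay schedule the question refers to, written with a plain PyTorch LambdaLR; warmup_iters, decay_every, and decay_factor are illustrative values, not LaneGCN settings.

import torch

model = torch.nn.Linear(10, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

warmup_iters = 1000                  # hypothetical values
decay_every, decay_factor = 30000, 0.1

def lr_lambda(it):
    if it < warmup_iters:                        # linear warm-up
        return (it + 1) / warmup_iters
    return decay_factor ** (it // decay_every)   # step decay every decay_every iterations

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

for it in range(3):                  # in a real training loop this runs every iteration
    opt.step()
    sched.step()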

Running preprocess_data.py raises an EOFError

(lanegcn) zzj@zzj-OMEN-25L-Desktop-GT12-1xxx:/mnt/data/pycharmmm/LaneGCN-master$ python preprocess_data.py -m lanegcn

71%|████████████████████████████████████████████████▋ | 1030/1459 [06:31<02:42, 2.63it/s]

Traceback (most recent call last):
File "preprocess_data.py", line 415, in
main()
File "preprocess_data.py", line 56, in main
val(config)
File "preprocess_data.py", line 130, in val
for i, data in enumerate(tqdm(val_loader)):
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/tqdm/std.py", line 1185, in iter
for obj in iterable:
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in next
data = self._next_data()
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 838, in _next_data
return self._process_data(data)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
data.reraise()
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/_utils.py", line 395, in reraise
raise self.exc_type(msg)
IndexError: Caught IndexError in DataLoader worker process 6.
Original Traceback (most recent call last):
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/mnt/data/pycharmmm/LaneGCN-master/data.py", line 85, in getitem
data = self.get_obj_feats(data)
File "/mnt/data/pycharmmm/LaneGCN-master/data.py", line 149, in get_obj_feats
orig = data['trajs'][0][19].copy().astype(np.float32)
IndexError: index 19 is out of bounds for axis 0 with size 16

Exception in thread Thread-2:
Traceback (most recent call last):
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/threading.py", line 926, in _bootstrap_inner
self.run()
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/threading.py", line 870, in run
self._target(*self._args, **self._kwargs)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/queues.py", line 113, in get
return _ForkingPickler.loads(res)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 294, in rebuild_storage_fd
fd = df.detach()
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
with _resource_sharer.get_connection(self._id) as conn:
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/resource_sharer.py", line 87, in get_connection
c = Client(address, authkey=process.current_process().authkey)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 498, in Client
answer_challenge(c, authkey)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 747, in answer_challenge
response = connection.recv_bytes(256) # reject large message
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
buf = self._recv(4)
File "/home/zzj/anaconda3/envs/lanegcn/lib/python3.7/multiprocessing/connection.py", line 383, in _recv
raise EOFError
EOFError

What can I do?
I am using a single 2080S on Ubuntu 20.04, with 16 GB of RAM and a PyCharm max heap size of 4 GB.
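
The IndexError above means one sequence's agent track has fewer than 20 observed steps. A quick diagnostic sketch (not part of the repository; the data path is illustrative and the standard Argoverse v1.1 forecasting columns TIMESTAMP and OBJECT_TYPE are assumed) that scans the raw val CSVs for such sequences:

import glob
import pandas as pd

for path in sorted(glob.glob("/path/to/argoverse/val/data/*.csv")):
    df = pd.read_csv(path)
    agent = df[df["OBJECT_TYPE"] == "AGENT"]
    n_steps = agent["TIMESTAMP"].nunique()
    if n_steps < 20:
        print(f"{path}: agent track has only {n_steps} timestamps")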

inference time

Sorry, I'm just posting a random question here. It appears that every object needs to go through the network once to predict its trajectory. I am wondering about the online inference time of your network (perhaps after ONNX export). Say you have 20 objects in the map: can that be done in real time (say, within a couple of milliseconds)?

Has anyone encountered the same issue below?

93%|██████████████████████████████████▍ | 5996/6436 [1:11:10<07:28, 1.02s/it]
Traceback (most recent call last):
File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/resource_sharer.py", line 142, in _serve
with self._listener.accept() as conn:
File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 456, in accept
answer_challenge(c, self._authkey)
File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 742, in answer_challenge
message = connection.recv_bytes(256) # reject large message
File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
buf = self._recv(4)
File "/home/luyu/miniconda3/envs/lanegcn-new/lib/python3.7/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer
killed

confused about the 'theta' param in actor preprocessing

Hi @chenyuntc, thanks for your source code.
I'm confused about the theta param in data.py's get_obj_feats.
I usually get theta from arctan2(pos[19].y - pos[18].y, pos[19].x - pos[18].x).
What is the specific meaning of your calculation here?

theta = np.pi - np.arctan2(pos[18].y-pos[19].y, pos[18].x - pos[19].x)

BR~
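
For what it's worth, a small numeric sketch (illustrative only; pos18 and pos19 are made-up points standing in for the last two observed positions) that evaluates the repository's expression next to the more common heading formula, with both wrapped into [-pi, pi) so they can be compared directly:

import numpy as np

pos18 = np.array([0.0, 0.0])          # hypothetical second-to-last observed position
pos19 = np.array([1.0, 2.0])          # hypothetical last observed position

theta_repo = np.pi - np.arctan2(pos18[1] - pos19[1], pos18[0] - pos19[0])
theta_common = np.arctan2(pos19[1] - pos18[1], pos19[0] - pos18[0])

wrap = lambda a: (a + np.pi) % (2 * np.pi) - np.pi   # map angles into [-pi, pi)
print(wrap(theta_repo), wrap(theta_common))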

Getting data Forbidden

bash get_data.sh

--2023-07-07 02:37:29-- https://s3.amazonaws.com/argoai-argoverse/hd_maps.tar.gz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.229.120, 52.216.39.56, 52.217.207.96, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.229.120|:443... connected.
HTTP request sent, awaiting response... 403 Forbidden
2023-07-07 02:37:30 ERROR 403: Forbidden.

Cannot download the data

When I execute the command python preprocess_data.py -m lanegcn, I encounter the following problem:

Traceback (most recent call last):
File "preprocess_data.py", line 21, in <module>
from data import ArgoDataset as Dataset, from_numpy, ref_copy, collate_fn
File "/home/ht1/LaneGCN/data.py", line 12, in <module>
from argoverse.map_representation.map_api import ArgoverseMap
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/map_representation/map_api.py", line 21, in <module>
from argoverse.utils.cv2_plotting_utils import get_img_contours
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/utils/cv2_plotting_utils.py", line 9, in <module>
from .calibration import CameraConfig, proj_cam_to_uv
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/utils/calibration.py", line 14, in <module>
from argoverse.utils.camera_stats import (
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/utils/camera_stats.py", line 8, in <module>
from argoverse.sensor_dataset_config import ArgoverseConfig
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/argoverse/sensor_dataset_config.py", line 55, in <module>
cfg = hydra.compose(config_name=f"{DATASET_NAME}.yaml")
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/compose.py", line 33, in compose
with_log_configuration=False,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/hydra.py", line 550, in compose_config
from_shell=from_shell,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/config_loader_impl.py", line 150, in load_configuration
from_shell=from_shell,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/config_loader_impl.py", line 244, in _load_configuration_impl
skip_missing=run_mode == RunMode.MULTIRUN,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 724, in create_defaults_list
skip_missing=skip_missing,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 695, in _create_defaults_list
skip_missing=skip_missing,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 343, in _create_defaults_tree
overrides=overrides,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 420, in _create_defaults_tree_impl
return _expand_virtual_root(repo, root, overrides, skip_missing)
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 268, in _expand_virtual_root
overrides=overrides,
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 427, in _create_defaults_tree_impl
config_not_found_error(repo=repo, tree=root)
File "/home/ht1/anaconda3/envs/lanegcn/lib/python3.7/site-packages/hydra/_internal/defaults_list.py", line 776, in config_not_found_error
options=options,
hydra.errors.MissingConfigException: Cannot find primary config 'argoverse-v1.1.yaml'. Check that it's in your config search path.

Config search path:
provider=hydra, path=pkg://hydra.conf
provider=main, path=pkg://argoverse.config
provider=schema, path=structured://

multi agent trajectory prediction?

Thank you for sharing your work and neat code!

In your paper (3.5 Learning),
the loss is used for training multiple agents (M is the number of agents in the scene).

But the pretrained checkpoint (36.000.ckpt) seems to predict only the first agent when computing scores such as minADE.

Do you also have a checkpoint trained for multi-agent prediction?
If so, could you share it?
The training process is too heavy for my server ;(

How were graph["left_pairs"] and graph["right_pairs"] defined in the lane graph? Is there a bug in these keys?

I printed the values of graph["left_pairs"] and graph["right_pairs"] and found that graph["left_pairs"] contains some duplicate node pairs, for example [37, 64] and [64, 37], but it is hard to understand how nodes 37 and 64 can each be the left neighbor of the other. There are no such repeated pairs in graph["right_pairs"]. Why is that?
Shouldn't graph["left_pairs"] be equal to the reverse of graph["right_pairs"]? For instance, graph["left_pairs"]: [[1,2],[3,4]],
graph["right_pairs"]: [[2,1],[4,3]].

The printed values of graph["left_pairs"] and graph["right_pairs"] are as follows:
len graph[left_pairs] 67 len graph[right_pairs] 27 graph[left_pairs] tensor([[15, 91], [19, 86], [24, 97], [29, 30], [30, 29], [31, 43], [32, 68], [33, 60], [34, 39], [35, 51], [36, 66], [37, 64], [38, 69], [39, 34], [40, 67], [41, 95], [42, 51], [43, 31], [44, 29], [45, 78], [46, 63], [47, 69], [48, 58], [49, 67], [50, 58], [51, 35], [52, 15], [53, 47], [54, 37], [55, 56], [56, 55], [57, 96], [58, 48], [59, 75], [61, 48], [62, 82], [63, 72], [64, 37], [65, 49], [66, 98], [67, 49], [68, 32], [69, 47], [70, 64], [71, 30], [72, 63], [73, 74], [74, 82], [75, 59], [76, 77], [77, 76], [79, 85], [81, 80], [82, 74], [83, 85], [84, 91], [85, 83], [88, 83], [89, 35], [90, 72], [91, 15], [93, 92], [94, 98], [95, 41], [96, 57], [97, 24], [98, 66]], device='cuda:0') graph[right_pairs] tensor([[15, 52], [29, 44], [30, 71], [35, 89], [37, 54], [47, 53], [48, 61], [49, 65], [51, 42], [58, 50], [60, 33], [63, 46], [64, 70], [66, 36], [67, 40], [69, 38], [72, 90], [74, 73], [78, 45], [80, 81], [82, 62], [83, 88], [85, 79], [86, 19], [91, 84], [92, 93], [98, 94]], device='cuda:0')
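
A small check sketch (not from the repository) for the symmetry the question asks about: it reports which reversed right_pairs are missing from left_pairs and vice versa. It is called here on the toy tensors from the example above; in practice one would pass graph["left_pairs"] and graph["right_pairs"] from a real batch.

import torch

def check_pair_symmetry(left_pairs, right_pairs):
    left = {tuple(p) for p in left_pairs.cpu().tolist()}
    right_rev = {(b, a) for a, b in right_pairs.cpu().tolist()}
    print("reversed right pairs missing from left:", sorted(right_rev - left))
    print("left pairs with no reversed right counterpart:", sorted(left - right_rev))

# Toy tensors from the example above; replace with real graph["left_pairs"] / graph["right_pairs"].
check_pair_symmetry(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[2, 1], [4, 3]]))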

Preprocessed data download link

Dear authors,
thank you for sharing your code. I was trying to train LaneGCN from scratch and had problems generating the training data. The data generation code freezes at the beginning, and the download link seems to have connection issues. I was wondering if you have any idea about that.
Thank you very much!

How can I get results on the test dataset?

I can get results on the val dataset by running the command: python test.py -m=lanegcn --weight=[ckpt PATH] --split=val --preprocess=True

but when I ran this command

python test.py -m=lanegcn --weight=[ckpt PATH] --split=test --preprocess=True

nothing was printed. How can I get results on the test dataset?

Evaluation results do not match the raw source files

Recently I wanted to visualize your prediction results, but I got a little confused about the correspondence between the idx of the preprocessed val data and $ARGO_RAW_DATA/val/data/argo_id.csv.
In ArgoTestDataset :

LaneGCN/data.py

Line 382 in 7e9b51d

data['argo_id'] = int(self.avl.seq_list[idx].name[:-4]) #160547

I downloaded your preprocessed val data and found that the predictions' argo_id does not match the source file $ARGO_RAW_DATA/val/data/argo_id.csv. My source data was downloaded from the Argoverse website (version 1.1). I want to know whether the order of the preprocessed data matches argo_id, or did I do something wrong?
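
One hedged way to inspect the correspondence: list the raw val CSVs in sorted path order (the order ArgoverseForecastingLoader is assumed to use here, which is exactly what needs checking) and print the argo_id that the slicing on data.py line 382 would assign to each index. The path is illustrative.

from pathlib import Path

val_dir = Path("/path/to/ARGO_RAW_DATA/val/data")   # illustrative path
seq_list = sorted(val_dir.glob("*.csv"))

for idx, p in enumerate(seq_list[:10]):
    argo_id = int(p.name[:-4])   # same slicing as data.py line 382: strip ".csv"
    print(idx, argo_id)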

Preprocessing the data is very slow

Dear authors,
thank you for sharing your code.
I downloaded the dataset from Argoverse and then wanted to preprocess it. When I run 'python preprocess_data.py -m lanegcn', it takes 5 hours with no output; the CPU occupancy rate is high but GPU usage is low.
Thank you very much!

left and right in the lane graph

Hi,
How do you find the "left" and "right" connections using 'pre', 'suc', 'pre_pairs', 'suc_pairs', 'left_pairs', 'right_pairs'?

How can I get the results in Table 1 of the paper?

I ran this command: python test.py -m lanegcn --weight=/absolute/path/to/36.000.ckpt --split=test, but nothing was printed except this:
2442it [15:14, 2.67it/s]
78143/78143
--Return--
None
> /laneGCN/test.py(114)main()
    113     generate_forecasting_h5(preds, f"{config['save_dir']}/submit.h5")  # this might take awhile
--> 114     import ipdb; ipdb.set_trace()
    115

Pretrain Model

Hi all,

Thank you for open-sourcing the code! It has been a great help!

However, I have recently been trying to use your pretrained model for inference. It seems the link in the README is invalid. Therefore, I am wondering if you would be kind enough to update the link or send me the pretrained model. I would really appreciate your help!

Best,

Ziqi

Training details

Hi,
According to the paper, section 4.1 (implementation details), you use a batch size of 128 and train for 36 epochs with a learning rate of 0.001, decayed to 0.0001 at epoch 32.

According to the provided code, the batch size is 32:

config["batch_size"] = 32

Does it give the same performance?

Also, one more question about the loss function: can you give more insight into the classification loss? Why do you need it, and have you tried training without it?

Thanks a lot for the great work.

Can you tell me the reason for randomness?

Hello, Thanks for your nice project.

I tried to train the model several times without editing the code, and found that the difference in performance between trials was quite large.
There is no big difference when testing multiple times with one checkpoint, so the randomness seems to occur during training.
Do you have any idea what the reason could be?

Thank you!
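
A generic sketch (not something the repository does) of pinning every RNG and forcing deterministic cuDNN kernels, which reduces but does not eliminate run-to-run variance; multi-worker data loading and Horovod can still introduce noise:

import random
import numpy as np
import torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True   # reproducible but slower kernels
    torch.backends.cudnn.benchmark = False

set_seed(42)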

"left" and "right" not in gragh

for k1 in ["left", "right"]:
    graph[k1] = dict()
    for k2 in ["u", "v"]:
        temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
        temp = [
            x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
            for x in temp
        ]
        graph[k1][k2] = torch.cat(temp)

KeyError: 'left'
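
One possible workaround sketch (an assumption, not an official fix): if a preprocessed graph was saved without "left"/"right" entries, fill them with empty index tensors before graph_gather so the torch.cat above has something to concatenate:

import torch

def ensure_left_right(graph):
    # Fill in empty edge lists so the torch.cat in graph_gather does not hit a KeyError.
    for k1 in ["left", "right"]:
        if k1 not in graph:
            graph[k1] = {k2: torch.zeros(0, dtype=torch.long) for k2 in ["u", "v"]}
    return graph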

Cannot download the pretrained model

Hi, @chenyuntc .

I want to download a pre-trained model.
Therefore, I pressed the "here" button shown in the picture below on your GitHub page, but nothing happened.
Could you check if there is any error? Or is there another way to download the pre-trained model?

Thank you,
[screenshot: download_pretrained_model]

IndexError: list index out of range

python test.py -m lanegcn --weight=/home/jovyan/LaneGCN/36.000.ckpt --split=test

0it [00:00, ?it/s]
Traceback (most recent call last):
File "/home/jovyan/LaneGCN/test.py", line 118, in
main()
File "/home/jovyan/LaneGCN/test.py", line 82, in main
for ii, data in tqdm(enumerate(data_loader)):
File "/home/venv/lib/python3.9/site-packages/tqdm/std.py", line 1178, in iter
for obj in iterable:
File "/home/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 634, in next
data = self._next_data()
File "/home/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 678, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/jovyan/LaneGCN/data.py", line 384, in getitem
data['argo_id'] = int(self.avl.seq_list[idx].name[:-4]) #160547
IndexError: list index out of range
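
One common cause of this IndexError is that the raw test split directory is empty or the path is wrong, so the loader's seq_list is shorter than expected. A quick sanity check (the test_obs path is illustrative):

from pathlib import Path

test_dir = Path("/path/to/argoverse/test_obs/data")   # illustrative path
print(len(list(test_dir.glob("*.csv"))), "sequences found in", test_dir)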

Does it really generate graph['node_idcs']?

Hi, thanks for releasing the code!

I have one question. When I generate the graph through preprocess_data.py, I don't find graph['node_idcs'] saved in the graph dict, but during training with lanegcn.py, while executing the MapNet module, I found the following code in the forward function of MapNet:
def forward(self, graph):
    if (
        len(graph["feats"]) == 0
        or len(graph["pre"][-1]["u"]) == 0
        or len(graph["suc"][-1]["u"]) == 0
    ):
        temp = graph["feats"]
        return (
            temp.new().resize_(0),
            [temp.new().long().resize_(0) for x in graph["node_idcs"]],
            temp.new().resize_(0),
        )
Although there is no error when I run the original lanegcn.py, is it because the if condition is never executed, or is graph['node_idcs'] actually generated somewhere that I didn't find? If this conditional statement were executed, the node features of the output would be empty.

Memory issue for multi-gpu training

Hi,

Thanks for the great work.

I am trying to train using Horovod with 4 GPUs (RTX 2080 Ti) and 80 GB of CPU memory. However, after some time, before training of the first epoch starts, I get the following error:

mpirun noticed that process rank 0 with PID 0 on node dagobert exited on signal 9 (Killed).

According to the Horovod GitHub, this seems to be an out-of-memory issue. Therefore, I would like to know the system requirements you used to train on 4 GPUs: GPU memory, CPU memory, number of CPUs, etc. Do you have any advice for multi-GPU training?

Training is much slower than described in the paper

Hi, I recently wanted to reproduce your results. I can reach the metrics described in the paper, but training takes almost 3 days, rather than the less than 12 hours described in the paper.

Environment:

  • 4 * Titan X (same as paper)
  • batch size 128 (4*32, same as paper)
  • changed the distributed framework from Horovod to PyTorch DDP, since Horovod is really hard to set up (even with the official Horovod Docker image I still got errors I couldn't resolve)

Did I do something wrong? I'm sure that I use DDP correctly, and I'm also sure that the bottleneck of training speed is the optimization (not I/O or anything else). Has anyone else met the same problem?

how to visualize the results

Your work has been a great help for a beginner, but I don't understand how to visualize the obtained results. Where are the imported cv2 modules used? I hope to get your reply; thank you again for your work.
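
A bare-bones visualization sketch using matplotlib rather than cv2 (not part of the repository); the array names and shapes are assumptions about typical inputs and outputs, with fake data standing in for lane centerlines, the observed track, and K predicted modes:

import numpy as np
import matplotlib.pyplot as plt

# Fake data: centerlines is a list of (N, 2) arrays, obs is (20, 2), preds is (K, 30, 2).
centerlines = [np.column_stack([np.linspace(0, 50, 25), np.full(25, y)]) for y in (-3.5, 0.0, 3.5)]
obs = np.column_stack([np.linspace(0, 10, 20), np.zeros(20)])
preds = np.stack([np.column_stack([np.linspace(10, 40, 30), np.linspace(0, dy, 30)])
                  for dy in (-3.5, 0.0, 3.5)])

for cl in centerlines:
    plt.plot(cl[:, 0], cl[:, 1], color="lightgray", linewidth=1)   # lane centerlines
for k in range(len(preds)):
    plt.plot(preds[k, :, 0], preds[k, :, 1], linestyle="--", label=f"mode {k}")
plt.plot(obs[:, 0], obs[:, 1], color="black", label="observed")
plt.axis("equal")
plt.legend()
plt.show()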

Question: What do "u" and "v" mean in data.py?

Hello,

Thanks for the great code.
Can you explain "u" and "v" for graph['pre'] and graph['suc'] in data.py?

pre, suc = dict(), dict()
for key in ['u', 'v']:
    pre[key], suc[key] = [], []
for i, lane_id in enumerate(lane_ids):
    lane = lanes[lane_id]
    idcs = node_idcs[i]
    
    pre['u'] += idcs[1:]
    pre['v'] += idcs[:-1]
    if lane.predecessors is not None:
        for nbr_id in lane.predecessors:
            if nbr_id in lane_ids:
                j = lane_ids.index(nbr_id)
                pre['u'].append(idcs[0])
                pre['v'].append(node_idcs[j][-1])
            
    suc['u'] += idcs[:-1]
    suc['v'] += idcs[1:]
    if lane.successors is not None:
        for nbr_id in lane.successors:
            if nbr_id in lane_ids:
                j = lane_ids.index(nbr_id)
                suc['u'].append(idcs[-1])
                suc['v'].append(node_idcs[j][0])

This also shows up in lanegcn.py:


def graph_gather(graphs):
    batch_size = len(graphs)
    node_idcs = []
    count = 0
    counts = []
    for i in range(batch_size):
        counts.append(count)
        idcs = torch.arange(count, count + graphs[i]["num_nodes"]).to(
            graphs[i]["feats"].device
        )
        node_idcs.append(idcs)
        count = count + graphs[i]["num_nodes"]

    graph = dict()
    graph["idcs"] = node_idcs
    graph["ctrs"] = [x["ctrs"] for x in graphs]

    for key in ["feats", "turn", "control", "intersect"]:
        graph[key] = torch.cat([x[key] for x in graphs], 0)

    for k1 in ["pre", "suc"]:
        graph[k1] = []
        for i in range(len(graphs[0]["pre"])):
            graph[k1].append(dict())
            for k2 in ["u", "v"]:
                graph[k1][i][k2] = torch.cat(
                    [graphs[j][k1][i][k2] + counts[j] for j in range(batch_size)], 0
                )

    for k1 in ["left", "right"]:
        graph[k1] = dict()
        for k2 in ["u", "v"]:
            temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
            temp = [
                x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for x in temp
            ]
            graph[k1][k2] = torch.cat(temp)
    return graph
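
For intuition, a minimal sketch (an assumption about usage, not the repository's exact MapNet code) of how such (u, v) edge lists are typically consumed: the feature of each source node v is aggregated into its destination node u with index_add_, which is one way the 'pre'/'suc' message passing over the lane graph can be implemented.

import torch

num_nodes, dim = 5, 4
feats = torch.randn(num_nodes, dim)          # one feature vector per lane node
u = torch.tensor([1, 2, 3, 4])               # destination nodes (receive a message)
v = torch.tensor([0, 1, 2, 3])               # their predecessor (source) nodes

agg = torch.zeros_like(feats)
agg.index_add_(0, u, feats[v])               # agg[u[i]] += feats[v[i]] for every edge i
out = feats + agg                            # e.g. a residual-style node update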


train.py

When I run train.py, it shows an error; according to my analysis, it does not reach the train function but fails while loading the data. Train1.py has the same error. Can you give me some advice? @chenyuntc
