tmaxedu / kordpr Goto Github PK
View Code? Open in Web Editor NEWThis repo Implements "Dense Passage Retrieval for Open-Domain Question Answering" using Korean Dataset
License: Other
This repo Implements "Dense Passage Retrieval for Open-Domain Question Answering" using Korean Dataset
License: Other
rm = KobertBiEncoder()
rm.load("/home/chanhwi/HeriGPT/KorDPR/2050iter_model/2050iter_model.pt")
index = DenseFlatIndexer()
index.deserialize(path="/home/chanhwi/HeriGPT/KorDPR/2050iter_model")
위와 같은 코드를 실행하면
RuntimeError Traceback (most recent call last)
/home/chanhwi/HeriGPT/test.ipynb 셀 3 line 2
[1](vscode-notebook-cell://ssh-remote%2B163.152.20.133/home/chanhwi/HeriGPT/test.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0) rm = KobertBiEncoder()
----> [2](vscode-notebook-cell://ssh-remote%2B163.152.20.133/home/chanhwi/HeriGPT/test.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1) rm.load("/home/chanhwi/HeriGPT/KorDPR/2050iter_model/2050iter_model.pt")
[3](vscode-notebook-cell://ssh-remote%2B163.152.20.133/home/chanhwi/HeriGPT/test.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2) # index = DenseFlatIndexer()
[4](vscode-notebook-cell://ssh-remote%2B163.152.20.133/home/chanhwi/HeriGPT/test.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3) # index.deserialize(path="/home/chanhwi/HeriGPT/KorDPR/2050iter_model")
File [~/HeriGPT/KorDPR/encoder.py:53](https://vscode-remote+ssh-002dremote-002b163-002e152-002e20-002e133.vscode-resource.vscode-cdn.net/home/chanhwi/HeriGPT/~/HeriGPT/KorDPR/encoder.py:53), in KobertBiEncoder.load(self, model_ckpt_path)
51 with open(model_ckpt_path, "rb") as f:
52 state_dict = torch.load(f)
---> 53 self.load_state_dict(state_dict)
54 logger.debug(f"model self.state_dict loaded from {model_ckpt_path}")
File [~/miniconda3/envs/heri/lib/python3.9/site-packages/torch/nn/modules/module.py:1482](https://vscode-remote+ssh-002dremote-002b163-002e152-002e20-002e133.vscode-resource.vscode-cdn.net/home/chanhwi/HeriGPT/~/miniconda3/envs/heri/lib/python3.9/site-packages/torch/nn/modules/module.py:1482), in Module.load_state_dict(self, state_dict, strict)
1477 error_msgs.insert(
1478 0, 'Missing key(s) in state_dict: {}. '.format(
1479 ', '.join('"{}"'.format(k) for k in missing_keys)))
1481 if len(error_msgs) > 0:
-> 1482 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1483 self.__class__.__name__, "\n\t".join(error_msgs)))
1484 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for KobertBiEncoder:
Unexpected key(s) in state_dict: "passage_encoder.embeddings.position_ids", "query_encoder.embeddings.position_ids".
위와같은 오류가 발생합니다.
어떻게 해결 할 수 있을까요?
일단, 모델을 훈련시키기 않고, indexing만 시켜서 작동을 하는지 테스트하는 과정에서 아래와 같은 오류가 발생되었습니다. 이에, 수정한 사항을 기록에 남깁니다. 필요하신 분들은 참고 부탁드려요~
( @TmaxEdu 이 레포를 공개해주셔서 너무 감사해요~)
수정사항(1)
self.index.search_knn(query_vectors=out.cpu().numpy(), top_docs=k) -> k를 int로 감싸주어야 작동했습니다.
Traceback (most recent call last):
File "/home/dydtjd95/workspace/KorDPR/retriever.py", line 104, in <module>
retriever.retrieve(query=args.query, k=args.k)
File "/home/dydtjd95/workspace/KorDPR/retriever.py", line 75, in retrieve
result = self.index.search_knn(query_vectors=out.cpu().numpy(), top_docs=k)
File "/home/dydtjd95/workspace/KorDPR/indexers.py", line 111, in search_knn
scores, indexes = self.index.search(query_vectors, top_docs)
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/faiss/class_wrappers.py", line 331, in replacement_search
assert k > 0
TypeError: '>' not supported between instances of 'str' and 'int'
수정사항(2)
수정한 코드
out = self.model(T(tok["input_ids"]).to(self.device), T(tok["attention_mask"]).to(self.device), "query")
수정한 이유
Traceback (most recent call last):
File "/home/dydtjd95/workspace/KorDPR/retriever.py", line 104, in <module>
retriever.retrieve(query=args.query, k=args.k)
File "/home/dydtjd95/workspace/KorDPR/retriever.py", line 74, in retrieve
out = self.model(T(tok["input_ids"]), T(tok["attention_mask"]), "query")
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/dydtjd95/workspace/KorDPR/encoder.py", line 42, in forward
return self.query_encoder(
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1007, in forward
embedding_output = self.embeddings(
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 231, in forward
inputs_embeds = self.word_embeddings(input_ids)
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 160, in forward
return F.embedding(
File "/home/dydtjd95/.conda/envs/kordpr/lib/python3.9/site-packages/torch/nn/functional.py", line 2210, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
수정 완료한 코드 전문
def retrieve(self, query: str, k: int = 100):
"""주어진 쿼리에 대해 가장 유사도가 높은 passage를 반환합니다."""
self.model.eval() # 평가 모드
tok = self.tokenizer.batch_encode_plus([query])
with torch.no_grad():
out = self.model(T(tok["input_ids"]).to(self.device), T(tok["attention_mask"]).to(self.device), "query")
result = self.index.search_knn(query_vectors=out.cpu().numpy(), top_docs=int(k))```
KorQuadDataset 안의
def _load_data(self):
with open(self.korquad_path, "rt", encoding="utf8") as f:
data = json.load(f)
self.raw_json = data["data"]
logger.debug("data loaded into self.raw_json")
with open("title_passage_map.p", "rb") as f:
self.title_passage_map = pickle.load(f)
logger.debug("title passage mapping loaded into self.title_passage_map")
이 부분을 보면 title_passage_map.p이라는 파일이 필요한 데 이 파일은 어디서 생성되는 파일인가요?
File "retriever.py", line 106, in <module>
retr_acc = retriever.val_top_k_acc()
File "retriever.py", line 43, in val_top_k_acc
q, q_mask, p_id, p, p_mask, a, a_mask = batch
ValueError: not enough values to unpack (expected 7, got 5)
retriever.py 실행시 다음과 같은 오류가 발생하는데 어떻게 해결해야 할까요?
해당 부분에서 사용되는 데이터셋은
valid_dataset = KorQuadDataset("dataset/KorQuAD_v1.0_dev.json")
인데, 이 class는 q, q_mask, p_id, p, p_mask
만을 리턴하는 것으로 보입니다.
답변 주시면 감사하겠습니다.
안녕하세요, 질문에 앞서 누구나 이해하고 따라가기 쉽게 코드 작성 및 공유 해주셔서 감사드립니다.
막 자연어 처리에 입문한 저에게는 정말 큰 도움이 되고 있습니다.
다름이 아니라 readme.md 에 작성하신 instruction을 따라가던 중 step 3에서 자꾸 걸려서 질문 드립니다.
현재 리포에서 활용되는 한국 위키 데이터 덤프 파일 kowiki-20220120-pages-articles.xml
을 공유받을 수 있을까요?
현재로써는 20200120 KorWiki dump data은 접근이 불가능한 상황입니다. 그래서 placeholder로 2024 데이터를 활용하여 이후 step을 따라가려 했지만, step4을 마치고 step5인 훈련으로 넘어가려 할 때에 dev split에서의 전처리된 데이터가 빈 데이터로 전처리가 된 것을 확인하였습니다.
그리고 그 이유를 살펴 본 결과, 제목과 매칭하는 과정에서 매칭되는 제목이 전혀 없어 빈 데이터로 전처리가 된 것을 확인하였습니다.
저는 현재 공유해주신 코드 그대로 전처리하여 훈련시키는 것과 더불어, 조금 다른 방식의 전처리 방식도 시도를 해보고 싶은 상황인지라 본 리포에서 활용중인 kowiki-20220120-pages-articles.xml
을 가지고 계시다면 혹시 공유 받을 수 있을지 여쭙고 싶습니다.
감사합니다.
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Some thing interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Some thing interesting about visualization, use data art
Some thing interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.