Comments (2)
Hi! @tpatzelt Thanks for your interest in our work!
The OOD detection part is done with separate scripts that run on top of this repo. Here are the two main ones:
- For generating experiment commands
# allennlp predict model/bow-sum[sst-2]-1 data/ForPredictors/sst2-dev.jsonl --cuda-device 0 --include-package allennlp_glue_patch --output-file output_dev.txt --predictor binary_sentiment_predictor --silent
#
# Print one shell command per (model, target) pair. Each command echoes its
# output name and runs `allennlp predict`, writing "<target>__<model>.pred.txt".
#
# The path arguments are single-quoted on purpose: model ids such as
# "bow-sum[sst-2]-1" contain "[...]", which an unquoted shell word treats as a
# glob bracket expression (zsh aborts with "no matches found"; bash silently
# substitutes a matching file name if one happens to exist).
command_template = (
    "echo \"{output_name}\" && "
    "allennlp predict 'model/{model_id}' 'data/ForPredictors/{target}.jsonl' "
    "--cuda-device 4 --include-package allennlp_glue_patch "
    "--output-file '{output_name}.pred.txt' "
    "--predictor binary_sentiment_predictor --silent"
)

# Trained model directories under model/.
models = [
    "bow-sum[sst-2]-1",
    "word2vec-sum[sst-2]-5",
    "word2vec-lstm[sst-2]-6",
    "word2vec-cnn[sst-2]-8",
    "glove-sum[sst-2]-10",
    "glove-lstm[sst-2]-16",
    "glove-cnn[sst-2]-18",
    "roberta-large-pool[sst-2]-22",
    "bert-large-pool[sst-2]-25",
    "bert-base-pool[sst-2]-26",
]

# Datasets (data/ForPredictors/<target>.jsonl) each model is run on.
targets = [
    "20ng",
    "multi30k",
    "sst2-dev",
    "wmt16",
    "snli_concat",
    "rte",
    "yelp-am",
]

for model_id in models:
    for target in targets:
        print(command_template.format(
            model_id=model_id,
            target=target,
            output_name="{}__{}".format(target, model_id),
        ))
- For collecting the results
import numpy as np
import sklearn.metrics as sk
# Default recall at which the false-positive rate is reported (i.e. "FPR95").
recall_level_default = 0.95


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Cumulative sum of the flattened ``arr`` computed in float64.

    The last partial sum is cross-checked against an independent float64
    ``np.sum``; a mismatch beyond the tolerances signals accumulated rounding
    error.

    Parameters
    ----------
    arr : array-like
        Values to sum cumulatively (flattened first).
    rtol : float
        Relative tolerance passed to ``np.allclose``.
    atol : float
        Absolute tolerance passed to ``np.allclose``.

    Returns
    -------
    numpy.ndarray
        1-D float64 array of running sums.

    Raises
    ------
    RuntimeError
        If the final running sum disagrees with the direct sum.
    """
    running = np.cumsum(arr, dtype=np.float64)
    total = np.sum(arr, dtype=np.float64)
    if np.allclose(running[-1], total, rtol=rtol, atol=atol):
        return running
    raise RuntimeError('cumsum was found to be unstable: '
                       'its last element does not correspond to sum')
def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None):
    """Return the false-positive rate at the threshold whose recall is closest to ``recall_level``.

    Despite the name, only the FPR is returned; the FDR is not computed.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels. When ``pos_label`` is None the set of labels must
        be {0, 1}, {-1, 1}, {0}, {-1} or {1}.
    y_score : array-like
        Scores; larger values are ranked as more likely to be positive.
    recall_level : float
        Target recall (0.95 by default, i.e. FPR95).
    pos_label : scalar, optional
        Label treated as positive; defaults to 1.

    Returns
    -------
    float
        Number of false positives at the chosen threshold divided by the total
        number of negatives (nan/inf with a numpy warning if there are none).

    Raises
    ------
    ValueError
        If the labels are not binary and ``pos_label`` is not given.
    """
    # Accept plain Python sequences as well as ndarrays: with lists,
    # ``y_true == pos_label`` collapses to a single bool and the fancy
    # indexing below raises TypeError.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    classes = np.unique(y_true)
    if (pos_label is None and
            not (np.array_equal(classes, [0, 1]) or
                 np.array_equal(classes, [-1, 1]) or
                 np.array_equal(classes, [0]) or
                 np.array_equal(classes, [-1]) or
                 np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    # sort scores in descending order together with their truth values
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # y_score typically has many tied values: keep only the last index of each
    # run of equal scores, plus the final index for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # accumulate the true positives with decreasing threshold
    tps = stable_cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps  # add one because of zero-based indexing

    recall = tps / tps[-1]

    # Keep the points up to the first one reaching full recall, reversed
    # (from full recall downwards), and append the sentinel (recall 1, 0 FPs).
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)  # [last_ind::-1]
    recall, fps = np.r_[recall[sl], 1], np.r_[fps[sl], 0]

    # first point whose recall is closest to the requested level
    cutoff = np.argmin(np.abs(recall - recall_level))
    return fps[cutoff] / (np.sum(np.logical_not(y_true)))
def get_measures(_pos, _neg, recall_level=recall_level_default):
    """Compute detection metrics for positive scores ``_pos`` vs negative scores ``_neg``.

    Scores from ``_pos`` are labelled 1 and scores from ``_neg`` are labelled 0.

    :param _pos: scores of the positive class (flattened)
    :param _neg: scores of the negative class (flattened)
    :param recall_level: recall at which the FPR is measured
    :return: ``(auroc, aupr, fpr)``
    """
    pos_col = np.array(_pos[:]).reshape((-1, 1))
    neg_col = np.array(_neg[:]).reshape((-1, 1))

    # Stack positives on top of negatives and flatten back to one vector.
    examples = np.squeeze(np.vstack((pos_col, neg_col)))

    # The first len(pos_col) entries are the positives.
    labels = np.zeros(len(examples), dtype=np.int32)
    labels[:len(pos_col)] = 1

    return (
        sk.roc_auc_score(labels, examples),
        sk.average_precision_score(labels, examples),
        fpr_and_fdr_at_recall(labels, examples, recall_level),
    )
def show_performance(pos, neg, method_name='Ours', recall_level=recall_level_default):
    '''Print FPR@recall, AUROC and AUPR (as percentages) for one method and return them.

    :param pos: 1's class, class to detect, outliers, or wrongly predicted
        example scores
    :param neg: 0's class scores
    :return: ``(fpr, auroc, aupr)`` -- note this order differs from the
        ``(auroc, aupr, fpr)`` order returned by ``get_measures``
    '''
    measures = get_measures(pos[:], neg[:], recall_level)
    auroc, aupr, fpr = measures

    print('\t\t\t' + method_name)
    for label, value in (('FPR{:d}:'.format(int(100 * recall_level)), fpr),
                         ('AUROC:', auroc),
                         ('AUPR:', aupr)):
        print('{}\t\t\t{:.2f}'.format(label, 100 * value))

    return fpr, auroc, aupr
def print_measures(auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default):
    """Print already-computed FPR / AUROC / AUPR values as percentages under a header."""
    print('\t\t\t\t' + method_name)
    for label, value in (('FPR{:d}:'.format(int(100 * recall_level)), fpr),
                         ('AUROC: ', auroc),
                         ('AUPR: ', aupr)):
        print('{}\t\t\t{:.2f}'.format(label, 100 * value))
def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default):
    """Print mean +/- standard deviation (as percentages) of repeated FPR / AUROC / AUPR runs."""
    print('\t\t\t\t' + method_name)
    for label, values in (('FPR{:d}:'.format(int(100 * recall_level)), fprs),
                          ('AUROC: ', aurocs),
                          ('AUPR: ', auprs)):
        print('{}\t\t\t{:.2f}\t+/- {:.2f}'.format(
            label, 100 * np.mean(values), 100 * np.std(values)))
def show_performance_comparison(pos_base, neg_base, pos_ours, neg_ours, baseline_name='Baseline',
                                method_name='Ours', recall_level=recall_level_default):
    '''Print FPR@recall, AUROC and AUPR of a baseline and of our method side by side.

    :param pos_base: 1's class, class to detect, outliers, or wrongly predicted
        example scores from the baseline
    :param neg_base: 0's class scores generated by the baseline
    :param pos_ours: same as ``pos_base`` but scored by our method
    :param neg_ours: same as ``neg_base`` but scored by our method
    '''
    auroc_base, aupr_base, fpr_base = get_measures(pos_base[:], neg_base[:], recall_level)
    auroc_ours, aupr_ours, fpr_ours = get_measures(pos_ours[:], neg_ours[:], recall_level)

    print('\t\t\t' + baseline_name + '\t' + method_name)
    comparison = (
        ('FPR{:d}:'.format(int(100 * recall_level)), fpr_base, fpr_ours),
        ('AUROC:', auroc_base, auroc_ours),
        ('AUPR:', aupr_base, aupr_ours),
    )
    for label, base_value, our_value in comparison:
        print('{}\t\t\t{:.2f}\t\t{:.2f}'.format(label, 100 * base_value, 100 * our_value))
import os, sys, json
argv = sys.argv  # NOTE(review): not read anywhere in the visible code
# All relative paths below (the *.pred.txt glob and OOD.tsv) resolve inside ./pred.
os.chdir("./pred")
# NOTE(review): neither name is referenced by the visible collection code.
base_file, ood_file = "output_train.txt", "output_dev.txt"
def read_file(f):
    """Read a JSON-lines prediction file and return one confidence score per line.

    Only the first two entries of each line's ``probs`` list are used. Lines
    whose two probabilities sum to less than 1e-4 are skipped. The score is
    ``-max(p0, p1) / (p0 + p1)``: it is negated so that larger values mean a
    less confident prediction.
    """
    print("read ->", f)
    scores = []
    with open(f) as fin:
        for line in fin:
            first_two = json.loads(line)['probs'][:2]
            total = sum(first_two)
            if total < 0.0001:
                continue
            scores.append(-(max(first_two) / total))
    return scores
# Evaluation datasets. NOTE(review): the collection loop below derives targets
# from the prediction file names instead of reading this list.
targets = ["20ng", "multi30k", "sst2-dev", "wmt16", "snli_concat", "rte", "yelp-am"]
import glob
import csv

# Load every prediction file, keyed by (target, model_id). File names have the
# form "<target>__<model_id>.pred.txt" (see the command generator). Sorting
# makes the output row order deterministic.
all_results = {}
for filename in sorted(glob.glob("*.pred.txt")):
    fileid = filename[:-len(".pred.txt")]
    target, modelid = fileid.split("__", 1)
    all_results[(target, modelid)] = read_file(filename)

# Scores on this target serve as the negative (0) class for every model.
baseline_target = "sst2-dev"
# baseline_target = "yelp-am"
min_examples = 100  # prediction files with fewer scores are ignored

printed_results = []
# newline="" is required by the csv module so it controls line endings itself.
with open("OOD.tsv", "w", newline="") as fout:
    writer = csv.writer(fout, delimiter="\t")
    writer.writerow(["model_id", "target", "len", "fpr", "auroc", "aupr"])
    for (target, modelid), score in all_results.items():
        if len(score) < min_examples:
            continue
        baseline = all_results.get((baseline_target, modelid))
        if baseline is None:
            # Previously a KeyError aborted the whole run here.
            print("skip", target, modelid, "-> no", baseline_target, "predictions for this model")
            continue
        print(target, modelid)
        # score = positives (class to detect), baseline = negatives
        printed_results = show_performance(score, baseline)
        fpr, auroc, aupr = printed_results
        writer.writerow([modelid, target, len(score), fpr, auroc, aupr])
Let me know if you have any further questions. Hope this helps!
from nlp-robustness.
@camelop Thank you, that already looks very helpful. I will try it out soon (a little busy at the moment).
from nlp-robustness.
Related Issues (3)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization, use data art
-
Game
Something interesting about games, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from nlp-robustness.