Skip to content

Commit efef9b0

Browse files
authored
Merge pull request #53 from IndicoDataSolutions/lily/rationalized_clean
Lily/rationalized clean
2 parents 876c016 + dde6a4c commit efef9b0

File tree

2 files changed

+238
-18
lines changed

2 files changed

+238
-18
lines changed

enso/config_rationalized.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import indicoio
2+
from enso.mode import ModeKeys
3+
import multiprocessing
4+
5+
"""Constants to configure the rest of Enso."""
6+
7+
# Directory for storing data
8+
DATA_DIRECTORY = "Data"
9+
10+
# Directory for storing results
11+
RESULTS_DIRECTORY = "Results"
12+
13+
# Directory for storing features
14+
FEATURES_DIRECTORY = "Features"
15+
16+
# Directory for storing experiment results
17+
EXPERIMENT_NAME = "Rationales"
18+
19+
# Name of the csv used to store results
20+
RESULTS_CSV_NAME = "Results.csv"
21+
22+
# Datasets to featurize or run experiments on
23+
DATA = {
24+
# "Classify/AirlineComplaints",
25+
# "Classify/AirlineNegativity",cRep
26+
# "Classify/IMDB",
27+
# "Classify/Irony",
28+
# "Classify/MPQA",
29+
# "Classify/MovieReviews",
30+
# "Classify/NewYearsResolutions",
31+
# "Classify/PoliticalTweetAlignment",
32+
# "Classify/PoliticalTweetBias",
33+
# "Classify/PoliticalTweetClassification",
34+
# "Classify/PoliticalTweetSubjectivity",
35+
# "Classify/PoliticalTweetTarget",
36+
# "Classify/ReligiousTexts",
37+
# "Classify/ShortAnswer",
38+
# "Classify/SocialMediaDisasters",
39+
# "Classify/Subjectivity",
40+
# "Classify/TextSpam",
41+
# "Classify/SST-binary"
42+
# Seqence
43+
# 'SequenceLabeling/Reuters-128',
44+
# "SequenceLabeling/table_synth",
45+
# 'SequenceLabeling/bonds_new',
46+
# 'SequenceLabeling/tables',
47+
# 'SequenceLabeling/typed_cols',
48+
# 'SequenceLabeling/brown_all',
49+
# 'SequenceLabeling/brown_nouns',
50+
# 'SequenceLabeling/brown_verbs',
51+
# 'SequenceLabeling/brown_pronouns',
52+
# 'SequenceLabeling/brown_adverbs',
53+
# 'RationalizedClassify/short_bank_qualified',
54+
# 'RationalizedClassify/bank_qualified',
55+
# 'RationalizedClassify/evidence_inference',
56+
# 'RationalizedClassify/federal_tax',
57+
# "RationalizedClassify/short_federal_tax",
58+
# 'RationalizedClassify/interest_frequency',
59+
# "RationalizedClassify/short_interest_frequency",
60+
"RationalizedClassify/aviation",
61+
# "RationalizedClassify/movie_reviews",
62+
# "RationalizedClassify/mining_rationales",
63+
# "RationalizedClassify/mining_extractions",
64+
# "RationalizedClassify/insurance_rationales",
65+
# "RationalizedClassify/insurance_extractions",
66+
# "RationalizedClassify/mining",
67+
# "RationalizedClassify/insurance_rationales_precise",
68+
# 'RationalizedClassify/short_bank_qualified',
69+
# 'RationalizedClassify/bank_qualified',
70+
# 'RationalizedClassify/short_bank_qualified_fixed',
71+
# 'RationalizedClassify/bank_qualified_fixed',
72+
# 'RationalizedClassify/short_bank_qualified_precise',
73+
# 'RationalizedClassify/bank_qualified_precise',
74+
}
75+
76+
# Featurizers to activate
77+
FEATURIZERS = {
78+
"PlainTextFeaturizer",
79+
# "TextContextFeaturizer",
80+
# "IndicoStandard",
81+
"SpacyGloveFeaturizer",
82+
# "IndicoFastText",
83+
# "IndicoSentiment",
84+
# "IndicoElmo",
85+
# "IndicoTopics",
86+
# "IndicoFinance",
87+
# "IndicoTransformer",
88+
# "IndicoEmotion",
89+
# "IndicoFastText",
90+
# "SpacyCNNFeaturizer",
91+
}
92+
93+
# Experiments to run
94+
EXPERIMENTS = {
95+
# "FinetuneSequenceLabel",
96+
# "Proto",
97+
# "IndicoSequenceLabel"
98+
"LRBaselineNonRationalized",
99+
"DistReweightedGloveClassifierCV",
100+
'DistReweightedGloveByClassClassifierCV'
101+
# "RationaleInformedLRCV"
102+
# "FinetuneSeqBaselineRationalized",
103+
# "FinetuneClfBaselineNonRationalized",
104+
# "LogisticRegressionCV",
105+
# "KNNCV",
106+
# "TfidfKNN",
107+
# "TfidfLogisticRegression",
108+
# "KCenters",
109+
# "TfidfKCenters"
110+
# "SupportVectorMachineCV",
111+
}
112+
113+
# Metrics to compute
114+
METRICS = {
115+
# "Accuracy",
116+
"AccuracyRationalized",
117+
"MacroRocAucRationalized",
118+
# "MacroRocAuc",
119+
# "MacroCharF1",
120+
# "MacroCharRecall",
121+
# "MacroCharPrecision",
122+
}
123+
124+
# Test setup metadata
125+
TEST_SETUP = {
126+
"train_sizes": [20, 40, 60, 80, 100, 150, 200, 300, 400, 500],
127+
"n_splits": 5,
128+
# "samplers": ['RandomRationalized'],
129+
# "samplers": ["ImbalanceSampler"],
130+
"samplers": ["RandomRationalized"],
131+
"sampling_size": 0.2,
132+
"resamplers": ["NoResampler"]
133+
# "resamplers": ["RandomOverSampler"],
134+
}
135+
136+
# Visualizations to display
137+
VISUALIZATIONS = {"FacetGridVisualizer"}
138+
139+
# kwargs to pass directly into visualizations
140+
VISUALIZATION_OPTIONS = {
141+
"display": True,
142+
"save": True,
143+
"FacetGridVisualizer": {
144+
"x_tile": "Metric",
145+
"y_tile": "Dataset",
146+
"x_axis": "TrainSize",
147+
"y_axis": "Result",
148+
"lines": ["Experiment", "Featurizer", "Sampler", "Resampler"],
149+
"category": "merge",
150+
"cv": "mean",
151+
"filename": "TestResult",
152+
},
153+
}
154+
155+
MODE = ModeKeys.RATIONALIZED
156+
157+
N_GPUS = 0
158+
N_CORES = 1 # multiprocessing.cpu_count()
159+
160+
FIX_REQUIREMENTS = True
161+
162+
GOLD_FRAC = 0.05
163+
CORRUPTION_FRAC = 0.4
164+
165+
indicoio.config.api_key = ""
166+
167+
# If we have no experiment hyperparameters we hope to modify:
168+
EXPERIMENT_PARAMS = {}

enso/experiment/rationalized.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from enso.experiment.grid_search import GridSearch
1010
from finetune import Classifier, SequenceLabeler
1111
from sklearn.preprocessing import LabelBinarizer
12-
from collections import Counter, defaultdict
12+
from collections import Counter, defaultdict, OrderedDict
1313

1414
class RationalizedGridSearch(GridSearch):
1515
def fit(self, X, y):
@@ -101,7 +101,10 @@ def cleanup(self):
101101

102102
@Registry.register_experiment(ModeKeys.RATIONALIZED, requirements=[("Featurizer", "PlainTextFeaturizer")])
103103
class ReweightedGloveClassifier(ClassificationExperiment):
104+
"""
105+
Weights words by their proportional occurrence as rationales, smoothed
104106
107+
"""
105108
NLP = None
106109

107110
def __init__(self, *args, **kwargs):
@@ -211,14 +214,10 @@ def fit(self, X, Y):
211214
rationales.append([{**label, "label": l[1]} for label in l[0]])
212215
else:
213216
rationales.append([])
214-
rationale_texts = [
215-
rationale['text']
216-
for doc in rationales
217-
for rationale in doc
218-
]
219-
docs = np.asarray([self.NLP(str(x), disable=['ner', 'tagger', 'textcat']) for x in X])
220-
rationale_docs = np.asarray([self.NLP(rationale) for rationale in rationale_texts if len(rationale)])
221-
self._train_rationale_model(docs, rationale_docs)
217+
rationale_texts = [rationale["text"] for doc in rationales for rationale in doc]
218+
docs = self.NLP.pipe(X, disable=["ner", "tagger", "textcat"])
219+
rationale_docs = np.asarray([self.NLP(rationale) if len(rationale) else None for rationale in rationale_texts])
220+
self._train_rationale_model(docs, rationale_docs, labels=labels)
222221

223222
doc_vects = np.asarray([self._featurize(doc) for doc in docs])
224223
resampled_x, resampled_y = self.resample(doc_vects, labels)
@@ -236,12 +235,15 @@ def predict(self, X, **kwargs):
236235

237236
@Registry.register_experiment(ModeKeys.RATIONALIZED, requirements=[("Featurizer", "PlainTextFeaturizer")])
238237
class DistReweightedGloveClassifierCV(BaseRationaleGridSearch):
238+
"""
239+
Weights words by cosine similarity to the mean of the rationale vector representations
239240
240-
def _train_rationale_model(self, docs, rationale_docs):
241+
"""
242+
def _train_rationale_model(self, docs, rationale_docs, labels=None):
241243
rationale_vecs = [
242-
doc.vector / np.linalg.norm(doc.vector)
243-
for doc in rationale_docs
244-
if doc.has_vector and np.any(np.nonzero(doc.vector))
244+
doc.vector / np.linalg.norm(doc.vector)
245+
for doc in rationale_docs
246+
if doc and doc.has_vector and np.any(np.nonzero(doc.vector))
245247
]
246248
rationale_proto = np.mean(rationale_vecs, axis=0)
247249
self.normalized_rationale_proto = rationale_proto / np.linalg.norm(rationale_proto)
@@ -264,21 +266,71 @@ def _featurize(self, doc):
264266

265267

266268
@Registry.register_experiment(ModeKeys.RATIONALIZED, requirements=[("Featurizer", "PlainTextFeaturizer")])
267-
class RationaleInformedLRCV(BaseRationaleGridSearch):
269+
class DistReweightedGloveByClassClassifierCV(BaseRationaleGridSearch):
270+
"""
271+
Weights words by cosine similarity to the mean of the rationale vector representations per class
272+
273+
"""
274+
def _train_rationale_model(self, docs, rationale_docs, labels=None):
275+
rationale_vecs_by_class = defaultdict(list)
276+
for doc, label in zip(rationale_docs, labels):
277+
if doc and doc.has_vector and np.any(np.nonzero(doc.vector)):
278+
rationale_vecs_by_class[label].append(
279+
doc.vector / np.linalg.norm(doc.vector)
280+
)
281+
rationale_proto_by_class = {
282+
label: np.mean(rationale_vecs, axis=0)
283+
for label, rationale_vecs in rationale_vecs_by_class.items()
284+
}
285+
self.normalized_rationale_proto_by_class = OrderedDict({
286+
label: rationale_proto / np.linalg.norm(rationale_proto)
287+
for label, rationale_proto in rationale_proto_by_class.items()
288+
})
268289

269-
def _train_rationale_model(self, docs, rationale_docs):
290+
def _rationale_weight(self, word, rationale_proto):
291+
cosine_sim = np.dot(word.vector / np.linalg.norm(word.vector), rationale_proto)
292+
return cosine_sim
293+
294+
def _featurize(self, doc):
295+
"""
296+
Take the mean representation, reweighted by the representations of
297+
each of the rationale prototypes
298+
299+
"""
300+
doc_vects = []
301+
for rationale_proto in self.normalized_rationale_proto_by_class.values():
302+
doc_vects.append(
303+
np.mean(
304+
[
305+
token.vector * self._rationale_weight(token, rationale_proto)
306+
for token in doc if self._valid(token)
307+
],
308+
axis=0
309+
)
310+
)
311+
doc_vect = np.mean(doc_vects, axis=0)
312+
313+
return doc_vect / np.linalg.norm(doc_vect)
314+
315+
316+
@Registry.register_experiment(ModeKeys.RATIONALIZED, requirements=[("Featurizer", "PlainTextFeaturizer")])
317+
class RationaleInformedLRCV(BaseRationaleGridSearch):
318+
"""
319+
Reweight document vectors by their similarity to a rationale vector, predicted by an LR model
320+
"""
321+
def _train_rationale_model(self, docs, rationale_docs, labels=None):
270322
rationale_vecs = [
271323
doc.vector / np.linalg.norm(doc.vector)
272324
for doc in rationale_docs
273325
if doc.has_vector and np.any(np.nonzero(doc.vector))
274326
]
275327
rationale_targets = [1] * len(rationale_vecs)
276328
background_vecs = [
277-
doc.vector / np.linalg.norm(doc.vector)
278-
for doc in rationale_docs
329+
doc.vector / np.linalg.norm(doc.vector)
330+
for doc in docs
279331
if doc.has_vector and np.any(np.nonzero(doc.vector))
280332
]
281-
background_targets = [0] * len(rationale_vecs)
333+
background_targets = [0] * len(background_vecs)
282334
X = rationale_vecs + background_vecs
283335
Y = rationale_targets + background_targets
284336

0 commit comments

Comments
 (0)