Skip to content

Commit 3ba5f8b

Browse files
Refactors the Predicate class in the PredPatt module to introduce a new PredicateType enumeration for better type safety and clarity. Updates the documentation to reflect changes in predicate type handling, enhancing usability and consistency across the module. Modifies various components to utilize the new enumeration, ensuring a more robust implementation of predicate types.
1 parent 196e64c commit 3ba5f8b

File tree

10 files changed

+141
-110
lines changed

10 files changed

+141
-110
lines changed

decomp/semantics/predpatt/core/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,20 @@
77

88
from .argument import Argument, sort_by_position
99
from .options import PredPattOpts
10-
from .predicate import AMOD, APPOS, NORMAL, POSS, Predicate, argument_names, no_color
10+
from .predicate import (
11+
Predicate,
12+
PredicateType,
13+
argument_names,
14+
no_color,
15+
)
1116
from .token import Token
1217

1318

1419
__all__ = [
15-
"AMOD",
16-
"APPOS",
17-
"NORMAL",
18-
"POSS",
1920
"Argument",
2021
"PredPattOpts",
2122
"Predicate",
23+
"PredicateType",
2224
"Token",
2325
"argument_names",
2426
"no_color",

decomp/semantics/predpatt/core/predicate.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,41 @@
1-
"""Predicate class for representing extracted predicates.
2-
3-
This module contains the Predicate class which represents predicates
4-
extracted from dependency parses, including their arguments and
5-
various predicate types (normal, possessive, appositive, adjectival).
1+
"""Predicate representation for semantic role labeling in PredPatt.
2+
3+
This module defines the core predicate structures used in the PredPatt system
4+
for extracting and representing predicates from dependency parses. It handles
5+
various predicate types including verbal, possessive, appositional, and
6+
adjectival predicates.
7+
8+
Key Components
9+
--------------
10+
:class:`Predicate`
11+
Main class representing a predicate with its root token, arguments, and
12+
predicate type. Supports different predicate types (normal, possessive,
13+
appositive, adjectival).
14+
15+
:class:`PredicateType`
16+
Enumeration defining the four types of predicates that PredPatt can extract:
17+
NORMAL, POSS, APPOS, and AMOD.
18+
19+
:func:`argument_names`
20+
Utility function to generate alphabetic names for arguments (?a, ?b, etc.)
21+
for display and debugging purposes.
22+
23+
:func:`sort_by_position`
24+
Helper function (provided by the ``argument`` module) to sort items by their position attribute, used for
25+
ordering tokens and arguments.
26+
27+
Predicate Types
28+
---------------
29+
The module defines a :class:`PredicateType` enum with four values:
30+
- ``PredicateType.NORMAL``: Standard verbal predicates
31+
- ``PredicateType.POSS``: Possessive predicates
32+
- ``PredicateType.APPOS``: Appositional predicates
33+
- ``PredicateType.AMOD``: Adjectival modifier predicates
634
"""
735

836
from __future__ import annotations
937

38+
import enum
1039
from typing import TYPE_CHECKING
1140

1241
from ..typing import T
@@ -23,11 +52,18 @@
2352

2453
ColorFunc = Callable[[str, str], str]
2554

26-
# Predicate type constants
27-
NORMAL = "normal"
28-
POSS = "poss"
29-
APPOS = "appos"
30-
AMOD = "amod"
55+
56+
class PredicateType(str, enum.Enum):
57+
"""Enumeration of predicate types in PredPatt.
58+
59+
Inherits from str to maintain backward compatibility with string comparisons.
60+
"""
61+
NORMAL = "normal" # Standard verbal predicates
62+
POSS = "poss" # Possessive predicates
63+
APPOS = "appos" # Appositional predicates
64+
AMOD = "amod" # Adjectival modifier predicates
65+
66+
3167

3268

3369
def argument_names(args: list[T]) -> dict[T, str]:
@@ -84,8 +120,8 @@ class Predicate:
84120
The Universal Dependencies module to use (default: dep_v1).
85121
rules : list, optional
86122
List of rules that led to this predicate's extraction.
87-
type_ : str, optional
88-
Type of predicate (NORMAL, POSS, APPOS, or AMOD).
123+
type_ : PredicateType, optional
124+
Type of predicate (PredicateType.NORMAL, POSS, APPOS, or AMOD).
89125
90126
Attributes
91127
----------
@@ -99,7 +135,7 @@ class Predicate:
99135
The UD version module being used.
100136
arguments : list[Argument]
101137
List of arguments for this predicate.
102-
type : str
138+
type : PredicateType
103139
Type of predicate.
104140
tokens : list[Token]
105141
List of tokens forming the predicate phrase.
@@ -110,7 +146,7 @@ def __init__(
110146
root: Token,
111147
ud: UDSchema = dep_v1,
112148
rules: list[Rule] | None = None,
113-
type_: str = NORMAL
149+
type_: PredicateType = PredicateType.NORMAL
114150
) -> None:
115151
"""Initialize a Predicate."""
116152
self.root = root
@@ -149,7 +185,7 @@ def identifier(self) -> str:
149185
Identifier in format 'pred.{type}.{position}.{arg_positions}'.
150186
"""
151187
arg_positions = '.'.join(str(a.position) for a in self.arguments)
152-
return f'pred.{self.type}.{self.position}.{arg_positions}'
188+
return f'pred.{self.type.value}.{self.position}.{arg_positions}'
153189

154190

155191
def has_token(self, token: Token) -> bool:
@@ -229,7 +265,7 @@ def share_subj(self, other: Predicate) -> bool | None:
229265
"""
230266
subj = self.subj()
231267
other_subj = other.subj()
232-
# use the exact same pattern as original to ensure identical behavior
268+
# check both subjects exist before comparing positions
233269
if subj is None or other_subj is None:
234270
return None
235271
return subj.position == other_subj.position
@@ -266,7 +302,7 @@ def is_broken(self) -> bool | None:
266302
return True
267303
if any(not a.tokens for a in self.arguments):
268304
return True
269-
if self.type == POSS and len(self.arguments) != 2:
305+
if self.type == PredicateType.POSS and len(self.arguments) != 2:
270306
return True
271307
return None
272308

@@ -288,12 +324,12 @@ def _format_predicate(self, name: dict[Argument, str], c: ColorFunc = no_color)
288324
# collect tokens and arguments
289325
x = sort_by_position(self.tokens + self.arguments)
290326

291-
if self.type == POSS:
327+
if self.type == PredicateType.POSS:
292328
# possessive format: "?a poss ?b" (the type's string value is the connective)
293329
assert len(self.arguments) == 2
294-
return f'{name[self.arguments[0]]} {self.type} {name[self.arguments[1]]}'
330+
return f'{name[self.arguments[0]]} {self.type.value} {name[self.arguments[1]]}'
295331

296-
elif self.type in {APPOS, AMOD}:
332+
elif self.type in {PredicateType.APPOS, PredicateType.AMOD}:
297333
# appositive/adjectival format: "?a is/are [rest]"
298334
# find governor argument
299335
gov_arg = None
@@ -381,7 +417,7 @@ def format(
381417
name = argument_names(self.arguments)
382418
for arg in self.arguments:
383419
if (arg.isclausal() and arg.root.gov in self.tokens and
384-
self.type == NORMAL):
420+
self.type == PredicateType.NORMAL):
385421
s = c('SOMETHING', 'yellow') + ' := ' + arg.phrase()
386422
else:
387423
s = c(arg.phrase(), 'green')

decomp/semantics/predpatt/extraction/engine.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
if TYPE_CHECKING:
1818
from ..core.argument import Argument
19-
from ..core.predicate import Predicate
19+
from ..core.predicate import Predicate, PredicateType
2020
from ..core.token import Token
2121
from ..parsing.udparse import DepTriple, UDParse
2222
from ..rules.base import Rule
23+
else:
24+
# PredicateType is used at runtime (default arguments, comparisons), so it must be imported outside TYPE_CHECKING
25+
from ..core.predicate import PredicateType
2326

24-
# predicate type constants
25-
NORMAL, POSS, APPOS, AMOD = ("normal", "poss", "appos", "amod")
2627

2728

2829
_PARSER = None
@@ -329,7 +330,7 @@ def identify_predicate_roots(self) -> list[Predicate]: # noqa: C901
329330

330331
roots = {}
331332

332-
def nominate(root: Token, rule: Rule, type_: str = NORMAL) -> Predicate:
333+
def nominate(root: Token, rule: Rule, type_: PredicateType = PredicateType.NORMAL) -> Predicate:
333334
"""Create or update a predicate instance with rules.
334335
335336
Parameters
@@ -338,8 +339,8 @@ def nominate(root: Token, rule: Rule, type_: str = NORMAL) -> Predicate:
338339
The root token of the predicate.
339340
rule : Rule
340341
The rule that identified this predicate.
341-
type_ : str, optional
342-
The predicate type (NORMAL, POSS, APPOS, AMOD).
342+
type_ : PredicateType, optional
343+
The predicate type (PredicateType.NORMAL, POSS, APPOS, AMOD).
343344
344345
Returns
345346
-------
@@ -360,17 +361,17 @@ def nominate(root: Token, rule: Rule, type_: str = NORMAL) -> Predicate:
360361

361362
# Special predicate types (conditional on options)
362363
if self.options.resolve_appos and e.rel == self.ud.appos:
363-
nominate(e.dep, R.D(), APPOS)
364+
nominate(e.dep, R.D(), PredicateType.APPOS)
364365

365366
if self.options.resolve_poss and e.rel == self.ud.nmod_poss:
366-
nominate(e.dep, R.V(), POSS)
367+
nominate(e.dep, R.V(), PredicateType.POSS)
367368

368369
# If resolve amod flag is enabled, then the dependent of an amod
369370
# arc is a predicate (but only if the dependent is an
370371
# adjective). We also filter cases where ADJ modifies ADJ.
371372
if (self.options.resolve_amod and e.rel == self.ud.amod
372373
and e.dep.tag == postag.ADJ and e.gov.tag != postag.ADJ):
373-
nominate(e.dep, R.E(), AMOD)
374+
nominate(e.dep, R.E(), PredicateType.AMOD)
374375

375376
# Avoid 'dep' arcs, they are normally parse errors.
376377
# Note: we allow amod, poss, and appos predicates, even with a dep arc.
@@ -480,7 +481,7 @@ def argument_extract(self, predicate: Predicate) -> list[Argument]: # noqa: C90
480481
# Nominal modifiers (h1 rule) - exclude AMOD predicates
481482
elif (e.rel is not None and
482483
(e.rel.startswith(self.ud.nmod) or e.rel.startswith(self.ud.obl))
483-
and predicate.type != AMOD):
484+
and predicate.type != PredicateType.AMOD):
484485
arguments.append(Argument(e.dep, self.ud, [R.H1()]))
485486

486487
# Clausal arguments (k rule)
@@ -499,19 +500,19 @@ def argument_extract(self, predicate: Predicate) -> list[Argument]: # noqa: C90
499500
arguments.append(Argument(tr.dep, self.ud, [R.H2()]))
500501

501502
# Special predicate type arguments
502-
if predicate.type == AMOD:
503+
if predicate.type == PredicateType.AMOD:
503504
# i rule: AMOD predicates get their governor
504505
if predicate.root.gov is None:
505506
raise ValueError(f"AMOD predicate {predicate.root} must have a governor but gov is None")
506507
arguments.append(Argument(predicate.root.gov, self.ud, [R.I()]))
507508

508-
elif predicate.type == APPOS:
509+
elif predicate.type == PredicateType.APPOS:
509510
# j rule: APPOS predicates get their governor
510511
if predicate.root.gov is None:
511512
raise ValueError(f"APPOS predicate {predicate.root} must have a governor but gov is None")
512513
arguments.append(Argument(predicate.root.gov, self.ud, [R.J()]))
513514

514-
elif predicate.type == POSS:
515+
elif predicate.type == PredicateType.POSS:
515516
# w1 rule: POSS predicates get their governor
516517
if predicate.root.gov is None:
517518
raise ValueError(f"POSS predicate {predicate.root} must have a governor but gov is None")
@@ -698,7 +699,7 @@ def _argument_resolution(self, events: list[Predicate]) -> list[Predicate]: # n
698699
# Portuguese. Without it, miss a lot of arguments.
699700
for p in sort_by_position(events):
700701
if (not p.has_subj()
701-
and p.type == NORMAL
702+
and p.type == PredicateType.NORMAL
702703
and p.root.gov_rel not in {self.ud.csubj, self.ud.csubjpass}
703704
and (p.root.gov_rel is None or not p.root.gov_rel.startswith(self.ud.acl))
704705
and not p.has_borrowed_arg()
@@ -793,7 +794,7 @@ def expand_coord(self, predicate: Predicate) -> list[Predicate]: # noqa: C901
793794
import itertools
794795

795796
# Don't expand coordination when resolve_conj is disabled or the predicate is AMOD
796-
if not self.options.resolve_conj or predicate.type == AMOD:
797+
if not self.options.resolve_conj or predicate.type == PredicateType.AMOD:
797798
predicate.arguments = [arg for arg in predicate.arguments if arg.tokens]
798799
if not predicate.arguments:
799800
return []
@@ -989,7 +990,7 @@ def _pred_phrase_extract(self, predicate: Predicate) -> None:
989990
from ..rules import predicate_rules as R # noqa: N812
990991

991992
assert predicate.tokens == []
992-
if predicate.type == POSS:
993+
if predicate.type == PredicateType.POSS:
993994
predicate.tokens = [predicate.root]
994995
return
995996
predicate.tokens.extend(self.subtree(predicate.root,
@@ -1178,7 +1179,7 @@ def _simple_arg(self, pred: Predicate, arg: Argument) -> bool:
11781179
"""
11791180
from ..rules import predicate_rules as R # noqa: N812
11801181

1181-
if pred.type == POSS:
1182+
if pred.type == PredicateType.POSS:
11821183
return True
11831184
if (pred.root.gov_rel in self.ud.ADJ_LIKE_MODS
11841185
and pred.root.gov == arg.root):

decomp/semantics/predpatt/utils/linearization.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
from collections.abc import Iterator
1919

2020
from ..core.argument import Argument
21-
from ..core.predicate import Predicate
21+
from ..core.predicate import Predicate, PredicateType
2222
from ..core.token import Token
2323
from ..extraction.engine import PredPattEngine
2424
from ..utils.ud_schema import DependencyRelationsV1, DependencyRelationsV2
2525

2626
UDSchema = type[DependencyRelationsV1] | type[DependencyRelationsV2]
2727
TokenIterator = Iterator[tuple[int, str]]
28+
else:
29+
# PredicateType is used at runtime (comparisons, .value access), so it must be imported outside TYPE_CHECKING
30+
from ..core.predicate import PredicateType
2831

2932

3033
class HasPosition(Protocol):
@@ -41,11 +44,6 @@ class HasChildren(Protocol):
4144

4245
T = TypeVar('T', bound=HasPosition)
4346

44-
# Import constants directly to avoid circular imports
45-
NORMAL = "normal"
46-
POSS = "poss"
47-
AMOD = "amod"
48-
APPOS = "appos"
4947

5048
# Regex patterns for parsing linearized forms
5149
RE_ARG_ENC = re.compile(r"\^\(\( | \)\)\$")
@@ -348,7 +346,7 @@ def flatten_pred(pred: Predicate, opt: LinearizedPPOpts, ud: UDSchema) -> tuple[
348346
args = pred.arguments
349347
child_preds = pred.children if hasattr(pred, 'children') else []
350348

351-
if pred.type == POSS:
349+
if pred.type == PredicateType.POSS:
352350
arg_i = 0
353351
# Only take the first two arguments into account.
354352
for y in sort_by_position(args[:2] + child_preds):
@@ -358,7 +356,7 @@ def flatten_pred(pred: Predicate, opt: LinearizedPPOpts, ud: UDSchema) -> tuple[
358356
arg_i += 1
359357
if arg_i == 1:
360358
# Generate the special ``poss'' predicate with label.
361-
poss = POSS + (PRED_HEADER if opt.distinguish_header
359+
poss = PredicateType.POSS.value + (PRED_HEADER if opt.distinguish_header
362360
else PRED_SUF)
363361
ret += [phrase_and_enclose_arg(arg_y, opt), poss]
364362
else:
@@ -371,7 +369,7 @@ def flatten_pred(pred: Predicate, opt: LinearizedPPOpts, ud: UDSchema) -> tuple[
371369
ret.append(repr_y)
372370
return ' '.join(ret), False
373371

374-
if pred.type in {AMOD, APPOS}:
372+
if pred.type in {PredicateType.AMOD, PredicateType.APPOS}:
375373
# Special handling for `amod` and `appos` because the target
376374
# relation `is/are` deviates from the original word order.
377375
arg0 = None

decomp/semantics/predpatt/utils/visualization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,15 @@ def format_predicate(
9696
str
9797
Formatted predicate string with argument placeholders
9898
"""
99-
from decomp.semantics.predpatt.core.predicate import AMOD, APPOS, POSS
99+
from decomp.semantics.predpatt.core.predicate import PredicateType
100100

101101
ret = []
102102
args = predicate.arguments
103103

104-
if predicate.type == POSS:
105-
return ' '.join([name[args[0]], c(POSS, 'yellow'), name[args[1]]])
104+
if predicate.type == PredicateType.POSS:
105+
return ' '.join([name[args[0]], c(PredicateType.POSS.value, 'yellow'), name[args[1]]])
106106

107-
if predicate.type in {AMOD, APPOS}:
107+
if predicate.type in {PredicateType.AMOD, PredicateType.APPOS}:
108108
# Special handling for `amod` and `appos` because the target
109109
# relation `is/are` deviates from the original word order.
110110
arg0 = None
@@ -173,7 +173,7 @@ def format_predicate_instance(
173173
str
174174
Formatted predicate instance with arguments listed below
175175
"""
176-
from decomp.semantics.predpatt.core.predicate import NORMAL
176+
from decomp.semantics.predpatt.core.predicate import PredicateType
177177

178178
lines = []
179179
name = argument_names(predicate.arguments)
@@ -190,7 +190,7 @@ def format_predicate_instance(
190190
# Format arguments
191191
for arg in predicate.arguments:
192192
if (arg.isclausal() and arg.root.gov in predicate.tokens and
193-
predicate.type == NORMAL):
193+
predicate.type == PredicateType.NORMAL):
194194
s = c('SOMETHING', 'yellow') + ' := ' + arg.phrase()
195195
else:
196196
s = c(arg.phrase(), 'green')

0 commit comments

Comments
 (0)