Skip to content

Commit 2b8c282

Browse files
committed
feat(labelling): add primitive labelling algorithm
Implement an algorithm that is a vast simplification of a CRF (conditional random field). It iterates over the sentence once, starting from the beginning, and chooses the best label for each word by evaluating every feature function.
1 parent 087e4bc commit 2b8c282

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

labelling/labelling.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package labelling
2+
3+
// FeatureFunction scores one candidate label for one word of a sentence.
//
// It receives the sentence as a slice of words, the index i of the word being
// labelled, the candidate label for that word (labelCurr) and the label given
// to the preceding word (labelPrev), and returns the feature's value.
type FeatureFunction func(sentence []string, i int, labelCurr, labelPrev string) float64
5+
6+
// Feature includes the weight and feature function for a feature
7+
type Feature struct {
8+
Weight float64
9+
Value FeatureFunction
10+
}
11+
12+
// FindBestLabelling determines the best labeling for the given sentence
13+
func FindBestLabelling(sentence []string, labels []string, features []Feature) []string {
14+
labelling := make([]string, 0)
15+
16+
for i := 0; i < len(sentence); i++ {
17+
bestScore, bestLabel, currentScore := -1.0, "", 0.0
18+
prevLabel := ""
19+
if i > 0 {
20+
prevLabel = labelling[i-1]
21+
}
22+
23+
for j := 0; j < len(labels); j++ {
24+
for k := 0; k < len(features); k++ {
25+
currentScore += (features[k].Weight * features[k].Value(sentence, i, labels[j], prevLabel))
26+
}
27+
28+
if currentScore > bestScore {
29+
bestScore = currentScore
30+
bestLabel = labels[j]
31+
}
32+
33+
currentScore = 0
34+
}
35+
36+
labelling = append(labelling, bestLabel)
37+
}
38+
39+
return labelling
40+
}

labelling/labelling_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package labelling
2+
3+
import (
4+
"reflect"
5+
"strings"
6+
"testing"
7+
)
8+
9+
// stringInArray reports whether list contains s, ignoring letter case.
func stringInArray(list []string, s string) bool {
	for _, candidate := range list {
		// EqualFold compares case-insensitively without building lowered
		// copies of both strings on every iteration.
		if strings.EqualFold(candidate, s) {
			return true
		}
	}

	return false
}
18+
19+
func isQuantityAtBeginning(sentence []string, i int, labelCurr string, labelPrev string) float64 {
20+
if i == 0 && stringInArray([]string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}, sentence[i]) && labelCurr == "quantity" {
21+
return 1
22+
}
23+
24+
return 0
25+
}
26+
27+
// unitFollowsQuantity is a feature function that returns 1 when the candidate
// label is "units" and the previous label is "quantity", and 0 otherwise.
func unitFollowsQuantity(sentence []string, i int, labelCurr string, labelPrev string) float64 {
	if labelPrev != "quantity" || labelCurr != "units" {
		return 0
	}

	return 1
}
34+
35+
// ingredientFollowsUnit is a feature function that returns 1 when the
// candidate label is "ingredient" and the previous label is "units", and 0
// otherwise.
func ingredientFollowsUnit(sentence []string, i int, labelCurr string, labelPrev string) float64 {
	if labelPrev != "units" || labelCurr != "ingredient" {
		return 0
	}

	return 1
}
42+
43+
func TestFindBestLabelling(t *testing.T) {
44+
sentence := strings.Split("1 cup apples", " ")
45+
labels := []string{"quantity", "units", "ingredient"}
46+
features := []Feature{
47+
{1.0, isQuantityAtBeginning},
48+
{1.0, unitFollowsQuantity},
49+
{1.0, ingredientFollowsUnit},
50+
}
51+
bestLabelling := FindBestLabelling(sentence, labels, features)
52+
53+
expectedLabelling := []string{"quantity", "units", "ingredient"}
54+
55+
if reflect.DeepEqual(bestLabelling, expectedLabelling) != true {
56+
t.Errorf("Expected %v to equal %v", bestLabelling, expectedLabelling)
57+
}
58+
}

0 commit comments

Comments
 (0)