diff --git a/pkg/pattern/template/point.go b/pkg/pattern/template/point.go index 024baad..3075265 100644 --- a/pkg/pattern/template/point.go +++ b/pkg/pattern/template/point.go @@ -19,6 +19,8 @@ var ( // ErrRelativePointRecursion is returned when a points are relative to itself. ErrRelativePointRecursion = errors.New("point cannot be relative to itself") + + ErrInvalidArguments = errors.New("invalid arguments to call function") ) // Points contains a map with points. @@ -40,52 +42,50 @@ var ErrInvalidPointID = errors.New("type cannot be converted to a PointID") func (p Points) Functions(pat *pattern.Pattern) map[string]govaluate.ExpressionFunction { return map[string]govaluate.ExpressionFunction{ "DistanceBetween": func(args ...interface{}) (interface{}, error) { - id0, err := toPointID(args[0]) - if err != nil { - return nil, fmt.Errorf("parsing args[0] to pointID: %w", err) - } - - id1, err := toPointID(args[1]) - if err != nil { - return nil, fmt.Errorf("parsing args[0] to pointID: %w", err) - } - - p0, err := p.getOrCreate(id0, pat, 0) - if err != nil { - return nil, fmt.Errorf("get or create point %q: %w", id0, err) + if len(args) != 2 { + return nil, fmt.Errorf("function DistanceBetween() requires 2 arguments: %w", + ErrInvalidArguments) } - p1, err := p.getOrCreate(id1, pat, 0) + points, err := p.getOrCreateFromArgs(pat, args...) if err != nil { - return nil, fmt.Errorf("get or create point %q: %w", id1, err) + return nil, err } - - return p0.Position().Distance(p1.Position()), nil + return points[0].Position().Distance(points[1].Position()), nil }, "AngleBetween": func(args ...interface{}) (interface{}, error) { - id0, err := toPointID(args[0]) - if err != nil { - return nil, fmt.Errorf("parsing args[0] to pointID: %w", err) + if len(args) != 2 { + return nil, fmt.Errorf("function AngleBetween() requires 2 arguments: %w", + ErrInvalidArguments) } - id1, err := toPointID(args[1]) + points, err := p.getOrCreateFromArgs(pat, args...) if err != nil { - return nil, fmt.Errorf("parsing args[0] to pointID: %w", err) + return nil, err } - p0, err := p.getOrCreate(id0, pat, 0) - if err != nil { - return nil, fmt.Errorf("get or create point %q: %w", id0, err) - } + return points[0].Vector().AngleBetween(points[1].Vector()), nil + }, + } +} - p1, err := p.getOrCreate(id1, pat, 0) - if err != nil { - return nil, fmt.Errorf("get or create point %q: %w", id1, err) - } +func (p Points) getOrCreateFromArgs(pat *pattern.Pattern, args ...interface{}) ([]point.Point, error) { + points := make([]point.Point, 0, len(args)) + for i, arg := range args { + id, err := toPointID(arg) + if err != nil { + return nil, fmt.Errorf("parsing args[%d] to pointID: %w", i, err) + } - return p0.Vector().AngleBetween(p1.Vector()), nil - }, + newPoint, err := p.getOrCreate(id, pat, 0) + if err != nil { + return nil, fmt.Errorf("get or create point %q: %w", id, err) + } + + points = append(points, newPoint) } + + return points, nil } func toPointID(arg interface{}) (point.ID, error) { @@ -111,9 +111,19 @@ type BetweenPoint struct { func (p Points) evaluationFunctions() map[string]govaluate.ExpressionFunction { return map[string]govaluate.ExpressionFunction{ "acos": func(args ...interface{}) (interface{}, error) { + if len(args) != 1 { + return nil, fmt.Errorf("function acos() requires 1 argument: %w", + ErrInvalidArguments) + } + return math.Acos(args[0].(float64)), nil }, "atan2": func(args ...interface{}) (interface{}, error) { + if len(args) != 2 { + return nil, fmt.Errorf("function atan2() requires 2 arguments: %w", + ErrInvalidArguments) + } + return math.Atan2(args[0].(float64), args[1].(float64)), nil }, }