diff --git a/go/src/miller/dsl/cst/udf.go b/go/src/miller/dsl/cst/udf.go index 93555b98e..11f00d3d2 100644 --- a/go/src/miller/dsl/cst/udf.go +++ b/go/src/miller/dsl/cst/udf.go @@ -19,6 +19,7 @@ type Signature struct { functionName string arity int // Computable from len(typeGatedParameterNames) at callee, not at caller typeGatedParameterNames []*types.TypeGatedMlrvalName + typeGatedReturnValue *types.TypeGatedMlrvalName // TODO: parameter typedecls // TODO: return-value typedecls @@ -28,11 +29,13 @@ func NewSignature( functionName string, arity int, typeGatedParameterNames []*types.TypeGatedMlrvalName, + typeGatedReturnValue *types.TypeGatedMlrvalName, ) *Signature { return &Signature{ functionName: functionName, arity: arity, typeGatedParameterNames: typeGatedParameterNames, + typeGatedReturnValue: typeGatedReturnValue, } } @@ -58,7 +61,7 @@ func NewUnresolvedUDF( functionName string, callsiteArity int, ) *UDF { - signature := NewSignature(functionName, callsiteArity, nil) + signature := NewSignature(functionName, callsiteArity, nil, nil) udf := NewUDF(signature, nil) return udf } @@ -132,6 +135,16 @@ func (this *UDFCallsite) Evaluate(state *State) types.Mlrval { // their UDF but we lost the return value. lib.InternalCodingErrorIf(blockExitPayload.blockReturnValue == nil) + err = this.udf.signature.typeGatedReturnValue.Check(blockExitPayload.blockReturnValue) + if err != nil { + // TODO: put error-return in the Evaluate API + fmt.Fprint( + os.Stderr, + err, + ) + os.Exit(1) + } + return *blockExitPayload.blockReturnValue } @@ -216,7 +229,17 @@ func (this *RootNode) BuildAndInstallUDF(astNode *dsl.ASTNode) error { functionName := string(astNode.Token.Lit) parameterListASTNode := astNode.Children[0] functionBodyASTNode := astNode.Children[1] - // TODO: optional typedecl 3rd arg + + returnValueTypeName := "var" + if len(astNode.Children) == 3 { + typeNode := astNode.Children[2] + lib.InternalCodingErrorIf(typeNode.Type != dsl.NodeTypeTypedecl) + returnValueTypeName = string(typeNode.Token.Lit) + } + typeGatedReturnValue, err := types.NewTypeGatedMlrvalName( + "function return value", + returnValueTypeName, + ) lib.InternalCodingErrorIf(parameterListASTNode.Type != dsl.NodeTypeParameterList) lib.InternalCodingErrorIf(parameterListASTNode.Children == nil) @@ -248,7 +271,7 @@ func (this *RootNode) BuildAndInstallUDF(astNode *dsl.ASTNode) error { typeGatedParameterNames[i] = typeGatedParameterName } - signature := NewSignature(functionName, arity, typeGatedParameterNames) + signature := NewSignature(functionName, arity, typeGatedParameterNames, typeGatedReturnValue) functionBody, err := this.BuildStatementBlockNode(functionBodyASTNode) if err != nil { diff --git a/go/u/try-cst b/go/u/try-cst index c5148330a..04b383277 100755 --- a/go/u/try-cst +++ b/go/u/try-cst @@ -429,3 +429,9 @@ run_mlr --from u/s.dkvp put 'func f(int x) { return 2*x} $y=f(3)' run_mlr --from u/s.dkvp put 'func f(num x) { return 2*x} $y=f(3)' mlr_expect_fail --from u/s.dkvp put 'func f(str x) { return 2*x} $y=f(3)' mlr_expect_fail --from u/s.dkvp put 'func f(arr x) { return 2*x} $y=f(3)' + +run_mlr --from u/s.dkvp put 'func f(x): var { return 2*x} $y=f(3)' +run_mlr --from u/s.dkvp put 'func f(x): int { return 2*x} $y=f(3)' +run_mlr --from u/s.dkvp put 'func f(x): num { return 2*x} $y=f(3)' +mlr_expect_fail --from u/s.dkvp put 'func f(x): str { return 2*x} $y=f(3)' +mlr_expect_fail --from u/s.dkvp put 'func f(x): arr { return 2*x} $y=f(3)' diff --git a/go/u/try-cst.out b/go/u/try-cst.out index 2a631b9b6..91362acbe 100644 --- a/go/u/try-cst.out +++ b/go/u/try-cst.out @@ -3150,3 +3150,49 @@ a=wye,b=wye,i=3,x=0.20460330576630303,y=0.33831852551664776 a=eks,b=wye,i=4,x=0.38139939387114097,y=0.13418874328430463 mlr --from u/s.dkvp put str x = 3 mlr --from u/s.dkvp put arr x = 3 + +---------------------------------------------------------------- +mlr --from u/s.dkvp put func f(var x) { return 2*x} $y=f(3) +a=pan,b=pan,i=1,x=0.3467901443380824,y=6 +a=eks,b=pan,i=2,x=0.7586799647899636,y=6 +a=wye,b=wye,i=3,x=0.20460330576630303,y=6 +a=eks,b=wye,i=4,x=0.38139939387114097,y=6 + +---------------------------------------------------------------- +mlr --from u/s.dkvp put func f(int x) { return 2*x} $y=f(3) +a=pan,b=pan,i=1,x=0.3467901443380824,y=6 +a=eks,b=pan,i=2,x=0.7586799647899636,y=6 +a=wye,b=wye,i=3,x=0.20460330576630303,y=6 +a=eks,b=wye,i=4,x=0.38139939387114097,y=6 + +---------------------------------------------------------------- +mlr --from u/s.dkvp put func f(num x) { return 2*x} $y=f(3) +a=pan,b=pan,i=1,x=0.3467901443380824,y=6 +a=eks,b=pan,i=2,x=0.7586799647899636,y=6 +a=wye,b=wye,i=3,x=0.20460330576630303,y=6 +a=eks,b=wye,i=4,x=0.38139939387114097,y=6 +mlr --from u/s.dkvp put func f(str x) { return 2*x} $y=f(3) +mlr --from u/s.dkvp put func f(arr x) { return 2*x} $y=f(3) + +---------------------------------------------------------------- +mlr --from u/s.dkvp put func f(x): var { return 2*x} $y=f(3) +a=pan,b=pan,i=1,x=0.3467901443380824,y=6 +a=eks,b=pan,i=2,x=0.7586799647899636,y=6 +a=wye,b=wye,i=3,x=0.20460330576630303,y=6 +a=eks,b=wye,i=4,x=0.38139939387114097,y=6 + +---------------------------------------------------------------- +mlr --from u/s.dkvp put func f(x): int { return 2*x} $y=f(3) +a=pan,b=pan,i=1,x=0.3467901443380824,y=6 +a=eks,b=pan,i=2,x=0.7586799647899636,y=6 +a=wye,b=wye,i=3,x=0.20460330576630303,y=6 +a=eks,b=wye,i=4,x=0.38139939387114097,y=6 + +---------------------------------------------------------------- +mlr --from u/s.dkvp put func f(x): num { return 2*x} $y=f(3) +a=pan,b=pan,i=1,x=0.3467901443380824,y=6 +a=eks,b=pan,i=2,x=0.7586799647899636,y=6 +a=wye,b=wye,i=3,x=0.20460330576630303,y=6 +a=eks,b=wye,i=4,x=0.38139939387114097,y=6 +mlr --from u/s.dkvp put func f(x): str { return 2*x} $y=f(3) +mlr --from u/s.dkvp put func f(x): arr { return 2*x} $y=f(3)