mirror of
https://github.com/johnkerl/miller.git
synced 2026-01-23 02:14:13 +00:00
Add surv Verb to Estimate a Survival Curve (#1788)
Add a surv verb to estimate a survival curve using Kaplan-Meier. It requires duration and status (event or censored) columns, and outputs each distinct duration and corresponding probability of survival.
This commit is contained in:
parent
35c7eeb977
commit
df73ad8ec0
9 changed files with 216 additions and 4 deletions
|
|
@ -70,6 +70,7 @@ var TRANSFORMER_LOOKUP_TABLE = []TransformerSetup{
|
|||
StepSetup,
|
||||
SubSetup,
|
||||
SummarySetup,
|
||||
SurvSetup,
|
||||
TacSetup,
|
||||
TailSetup,
|
||||
TeeSetup,
|
||||
|
|
|
|||
173
pkg/transformers/surv.go
Normal file
173
pkg/transformers/surv.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
package transformers
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/johnkerl/miller/v6/pkg/cli"
|
||||
"github.com/johnkerl/miller/v6/pkg/mlrval"
|
||||
"github.com/johnkerl/miller/v6/pkg/types"
|
||||
"github.com/kshedden/statmodel/duration"
|
||||
"github.com/kshedden/statmodel/statmodel"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
const verbNameSurv = "surv"
|
||||
|
||||
// SurvSetup defines the surv verb: Kaplan-Meier survival curve.
|
||||
var SurvSetup = TransformerSetup{
|
||||
Verb: verbNameSurv,
|
||||
UsageFunc: transformerSurvUsage,
|
||||
ParseCLIFunc: transformerSurvParseCLI,
|
||||
IgnoresInput: false,
|
||||
}
|
||||
|
||||
func transformerSurvUsage(o *os.File) {
|
||||
fmt.Fprintf(o, "Usage: %s %s -d {duration-field} -s {status-field}\n", "mlr", verbNameSurv)
|
||||
fmt.Fprint(o, `
|
||||
Estimate Kaplan-Meier survival curve (right-censored).
|
||||
Options:
|
||||
-d {field} Name of duration field (time-to-event or censoring).
|
||||
-s {field} Name of status field (0=censored, 1=event).
|
||||
-h, --help Show this message.
|
||||
`)
|
||||
}
|
||||
|
||||
func transformerSurvParseCLI(
|
||||
pargi *int,
|
||||
argc int,
|
||||
args []string,
|
||||
_ *cli.TOptions,
|
||||
doConstruct bool,
|
||||
) IRecordTransformer {
|
||||
argi := *pargi
|
||||
verb := args[argi]
|
||||
argi++
|
||||
|
||||
var durationField, statusField string
|
||||
|
||||
for argi < argc {
|
||||
opt := args[argi]
|
||||
if !strings.HasPrefix(opt, "-") {
|
||||
break
|
||||
}
|
||||
if opt == "-h" || opt == "--help" {
|
||||
transformerSurvUsage(os.Stdout)
|
||||
os.Exit(0)
|
||||
} else if opt == "-d" {
|
||||
if argi+1 >= argc {
|
||||
fmt.Fprintf(os.Stderr, "mlr %s: %s requires an argument\n", verb, opt)
|
||||
os.Exit(1)
|
||||
}
|
||||
argi++
|
||||
durationField = args[argi]
|
||||
argi++
|
||||
} else if opt == "-s" {
|
||||
if argi+1 >= argc {
|
||||
fmt.Fprintf(os.Stderr, "mlr %s: %s requires an argument\n", verb, opt)
|
||||
os.Exit(1)
|
||||
}
|
||||
argi++
|
||||
statusField = args[argi]
|
||||
argi++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
*pargi = argi
|
||||
if !doConstruct {
|
||||
return nil
|
||||
}
|
||||
if durationField == "" {
|
||||
fmt.Fprintf(os.Stderr, "mlr %s: -d option is required.\n", verbNameSurv)
|
||||
fmt.Fprintf(os.Stderr, "Please see 'mlr %s --help' for more information.\n", verbNameSurv)
|
||||
os.Exit(1)
|
||||
}
|
||||
if statusField == "" {
|
||||
fmt.Fprintf(os.Stderr, "mlr %s: -s option is required.\n", verbNameSurv)
|
||||
fmt.Fprintf(os.Stderr, "Please see 'mlr %s --help' for more information.\n", verbNameSurv)
|
||||
os.Exit(1)
|
||||
}
|
||||
return NewTransformerSurv(durationField, statusField)
|
||||
}
|
||||
|
||||
// TransformerSurv holds fields for surv verb.
|
||||
type TransformerSurv struct {
|
||||
durationField string
|
||||
statusField string
|
||||
times []float64
|
||||
events []bool
|
||||
}
|
||||
|
||||
// NewTransformerSurv constructs a new surv transformer.
|
||||
func NewTransformerSurv(durationField, statusField string) IRecordTransformer {
|
||||
return &TransformerSurv{
|
||||
durationField: durationField,
|
||||
statusField: statusField,
|
||||
times: make([]float64, 0),
|
||||
events: make([]bool, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Transform processes each record or emits results at end-of-stream.
|
||||
func (tr *TransformerSurv) Transform(
|
||||
inrecAndContext *types.RecordAndContext,
|
||||
outputRecordsAndContexts *list.List,
|
||||
inputDownstreamDoneChannel <-chan bool,
|
||||
outputDownstreamDoneChannel chan<- bool,
|
||||
) {
|
||||
HandleDefaultDownstreamDone(inputDownstreamDoneChannel, outputDownstreamDoneChannel)
|
||||
if !inrecAndContext.EndOfStream {
|
||||
rec := inrecAndContext.Record
|
||||
mvDur := rec.Get(tr.durationField)
|
||||
if mvDur == nil {
|
||||
fmt.Fprintf(os.Stderr, "mlr surv: duration field '%s' not found\n", tr.durationField)
|
||||
os.Exit(1)
|
||||
}
|
||||
duration := mvDur.GetNumericToFloatValueOrDie()
|
||||
mvStat := rec.Get(tr.statusField)
|
||||
if mvStat == nil {
|
||||
fmt.Fprintf(os.Stderr, "mlr surv: status field '%s' not found\n", tr.statusField)
|
||||
os.Exit(1)
|
||||
}
|
||||
status := mvStat.GetNumericToFloatValueOrDie() != 0
|
||||
tr.times = append(tr.times, duration)
|
||||
tr.events = append(tr.events, status)
|
||||
} else {
|
||||
// Compute survival using kshedden/statmodel
|
||||
n := len(tr.times)
|
||||
if n == 0 {
|
||||
outputRecordsAndContexts.PushBack(inrecAndContext)
|
||||
return
|
||||
}
|
||||
durations := tr.times
|
||||
statuses := make([]float64, n)
|
||||
for i, ev := range tr.events {
|
||||
if ev {
|
||||
statuses[i] = 1.0
|
||||
} else {
|
||||
statuses[i] = 0.0
|
||||
}
|
||||
}
|
||||
dataCols := [][]float64{durations, statuses}
|
||||
names := []string{tr.durationField, tr.statusField}
|
||||
ds := statmodel.NewDataset(dataCols, names)
|
||||
sf, err := duration.NewSurvfuncRight(ds, tr.durationField, tr.statusField, &duration.SurvfuncRightConfig{})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "mlr surv: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
sf.Fit()
|
||||
times := sf.Time()
|
||||
survProbs := sf.SurvProb()
|
||||
for i, t := range times {
|
||||
newrec := mlrval.NewMlrmapAsRecord()
|
||||
newrec.PutCopy("time", mlrval.FromFloat(t))
|
||||
newrec.PutCopy("survival", mlrval.FromFloat(survProbs[i]))
|
||||
outputRecordsAndContexts.PushBack(types.NewRecordAndContext(newrec, &inrecAndContext.Context))
|
||||
}
|
||||
outputRecordsAndContexts.PushBack(inrecAndContext)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue