
API

Training

MLJModelInterface.fit Function
```julia
function fit(
    config::NeuroTypes,
    dtrain;
    feature_names,
    target_name,
    weight_name=nothing,
    offset_name=nothing,
    deval=nothing,
    metric=nothing,
    print_every_n=9999,
    early_stopping_rounds=9999,
    verbosity=1,
    device=:cpu,
    gpuID=0,
)
```

Training function of NeuroTabModels' internal API.

Arguments

  • config::NeuroTypes: the model configuration.

  • dtrain: the training data; must be an AbstractDataFrame.

Keyword arguments

  • feature_names: Required kwarg; a Vector{Symbol} or Vector{String} of the feature names.

  • target_name: Required kwarg; a Symbol or String giving the name of the target variable.

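  • weight_name=nothing: name of the column holding sample weights, if any.

  • offset_name=nothing: name of the column holding offsets, if any.

  • deval=nothing: data used to track the evaluation metric and to perform early stopping.

  • metric=nothing: evaluation metric tracked on deval.

  • print_every_n=9999: how often, in rounds, the evaluation metric is logged.

  • early_stopping_rounds=9999: number of rounds without improvement on deval after which training stops.

  • verbosity=1

  • device=:cpu: device used for training.

  • gpuID=0: index of the GPU used when training on a GPU device.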
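A minimal sketch of a training call with early stopping, assuming NeuroTabModels and DataFrames are installed. Several pieces are assumptions, not confirmed by this page: `NeuroTabRegressor()` with default hyperparameters is used as a hypothetical `NeuroTypes` config, `metric = :mse` is an assumed metric name, and the call is written as `NeuroTabModels.fit` on the assumption that the package extends `MLJModelInterface.fit` under that name. The synthetic data is purely illustrative.

```julia
# Sketch only: assumes NeuroTabModels and DataFrames are installed, that
# `NeuroTabRegressor()` is a valid config, and that `:mse` is a valid metric.
using NeuroTabModels
using DataFrames
using Random

Random.seed!(123)

# Synthetic regression data: y depends linearly on x1 and x2, plus noise.
n = 1_000
df = DataFrame(x1 = randn(Float32, n), x2 = randn(Float32, n))
df.y = 2f0 .* df.x1 .- df.x2 .+ 0.1f0 .* randn(Float32, n)

# Random 80/20 split into training and evaluation sets.
idx = randperm(n)
dtrain = df[idx[1:800], :]
deval = df[idx[801:end], :]

config = NeuroTabRegressor()  # hypothetical config choice, default hyperparameters

m = NeuroTabModels.fit(
    config,
    dtrain;
    feature_names = [:x1, :x2],  # Vector{Symbol} of feature columns
    target_name = :y,            # Symbol naming the target column
    deval = deval,               # evaluation data for metric tracking and early stopping
    metric = :mse,               # assumed metric name
    print_every_n = 10,          # log the metric every 10 rounds
    early_stopping_rounds = 5,   # stop after 5 rounds without improvement on deval
)
```

Passing `deval` together with a finite `early_stopping_rounds` is what allows training to stop before the configured number of rounds; with the defaults (`9999`), early stopping and logging are effectively disabled.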


Inference

NeuroTabModels.Infer.infer Function

```julia
infer(m::NeuroTabModel, data)
```

Return the predictions of a NeuroTabModel on data, where data is an AbstractDataFrame.
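A minimal sketch of inference with the documented `NeuroTabModels.Infer.infer`, under the same assumptions as for training: NeuroTabModels and DataFrames are installed, `NeuroTabRegressor()` is a hypothetical config choice, `NeuroTabModels.fit` resolves to the documented training method, and the data is synthetic. The predictions are assumed to contain one entry per row of `data`.

```julia
# Sketch only: assumes NeuroTabModels and DataFrames are installed and that
# `NeuroTabRegressor()` (default hyperparameters) is a valid config.
using NeuroTabModels
using DataFrames
using Random

Random.seed!(42)

# Synthetic regression data with two features.
n = 500
df = DataFrame(x1 = randn(Float32, n), x2 = randn(Float32, n))
df.y = df.x1 .+ 0.5f0 .* df.x2

# First 400 rows for training, remaining 100 rows held out.
dtrain = df[1:400, :]
dtest = df[401:end, :]

m = NeuroTabModels.fit(
    NeuroTabRegressor(),
    dtrain;
    feature_names = [:x1, :x2],
    target_name = :y,
)

# Predictions of the fitted model on the held-out rows.
p = NeuroTabModels.Infer.infer(m, dtest)
```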
