Uploaded image for project: 'Apache MADlib'
  1. Apache MADlib
  2. MADLIB-1217

DT and RF: null_as_category null handling does not work with grouping

    XML | Word | Printable | JSON

Details

    Description

      Load data:

      -- Fixture for the reproduction. Only id is declared NOT NULL; cyl is
      -- deliberately NULL in rows 9 and 18 (both with am = 0), which is the
      -- missing data that null_as_category=TRUE is expected to handle.
      DROP TABLE IF EXISTS mt_cars;
      
      CREATE TABLE mt_cars (
          id integer NOT NULL,
          mpg double precision,
          cyl integer,
          disp double precision,
          hp integer,
          drat double precision,
          wt double precision,
          qsec double precision,
          vs integer,
          am integer,
          gear integer,
          carb integer
      );
      
      -- Explicit column list so the load does not depend on the table's
      -- physical column order.
      INSERT INTO mt_cars
          (id, mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb)
      VALUES
      (1,18.7,8,360,175,3.15,3.44,17.02,0,0,3,2),
      (2,21,6,160,110,3.9,2.62,16.46,0,1,4,4),
      (3,24.4,4,146.7,62,3.69,3.19,20,1,0,4,2),
      (4,21,6,160,110,3.9,2.875,17.02,0,1,4,4),
      (5,17.8,6,167.6,123,3.92,3.44,18.9,1,0,4,4),
      (6,16.4,8,275.8,180,3.078,4.07,17.4,0,0,3,3),
      (7,22.8,4,108,93,3.85,2.32,18.61,1,1,4,1),
      (8,17.3,8,275.8,180,3.078,3.73,17.6,0,0,3,3),
      (9,21.4,null,258,110,3.08,3.215,19.44,1,0,3,1),
      (10,15.2,8,275.8,180,3.078,3.78,18,0,0,3,3),
      (11,18.1,6,225,105,2.768,3.46,20.22,1,0,3,1),
      (12,32.4,4,78.7,66,4.08,2.20,19.47,1,1,4,1),
      (13,14.3,8,360,245,3.21,3.578,15.84,0,0,3,4),
      (14,22.8,4,140.8,95,3.92,3.15,22.9,1,0,4,2),
      (15,30.4,4,75.7,52,4.93,1.615,18.52,1,1,4,2),
      (16,19.2,6,167.6,123,3.92,3.44,18.3,1,0,4,4),
      (17,33.9,4,71.14,65,4.22,1.835,19.9,1,1,4,1),
      (18,15.2,null,304,150,3.15,3.435,17.3,0,0,3,2),
      (19,10.4,8,472,205,2.93,5.25,17.98,0,0,3,4),
      (20,27.3,4,79,66,4.08,1.935,18.9,1,1,4,1),
      (21,10.4,8,460,215,3,5.424,17.82,0,0,3,4),
      (22,26,4,120.3,91,4.43,2.14,16.7,0,1,5,2),
      (23,14.7,8,440,230,3.23,5.345,17.42,0,0,3,4),
      (24,30.4,4,95.14,113,3.77,1.513,16.9,1,1,5,2),
      (25,21.5,4,120.1,97,3.70,2.465,20.01,1,0,3,1),
      (26,15.8,8,351,264,4.22,3.17,14.5,0,1,5,4),
      (27,15.5,8,318,150,2.768,3.52,16.87,0,0,3,2),
      (28,15,8,301,335,3.54,3.578,14.6,0,1,5,8),
      (29,13.3,8,350,245,3.73,3.84,15.41,0,0,3,4),
      (30,19.2,8,400,175,3.08,3.845,17.05,0,0,3,2),
      (31,19.7,6,145,175,3.62,2.77,15.5,0,1,5,6),
      (32,21.4,4,121,109,4.11,2.78,18.6,1,1,4,2);
      

      DT:

      -- Decision-tree reproduction: training with a grouping column ('am')
      -- together with null_as_category=TRUE triggers the failure.
      DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv;
      
      SELECT madlib.tree_train('mt_cars',         -- source table
                               'train_output',    -- output model table
                               'id',              -- id column
                               'mpg',             -- dependent variable
                               '*',               -- features
                               'id, hp, drat, am, gear, carb',  -- exclude columns
                               'mse',             -- split criterion
                               'am',        -- grouping
                               NULL::text,        -- no weights, all observations treated equally
                               10,                -- max depth
                               8,                 -- min split
                               3,                 -- min bucket
                               10,                -- number of splits (bins per continuous variable)
                               NULL,              -- pruning parameters
                               'null_as_category=TRUE'  -- null handling parameters
                               );
      

      results in error:

      ERROR:  plpy.SPIError: column "__null__" does not exist
      LINE 14:                             (COALESCE(vs, __NULL__))::text a...
                                                         ^
      QUERY:  
                      SELECT
                          colname::text,
                          levels::text[],
                          grp_key::text
                      from (
                          SELECT
                              grp_key,
                              'vs' as colname,
                              array_agg(levels order by dep_avg) as levels
                          from (
                              SELECT
                                  array_to_string(array[(am)::text]::text[], ',') as grp_key,
                                  (COALESCE(vs, __NULL__))::text as levels,
                                  COALESCE(vs, __NULL__) as dep_avg
                              FROM mt_cars
                              WHERE (am) is not NULL and (mpg) is not NULL AND COALESCE(vs, __NULL__) IS NOT NULL
                              GROUP BY COALESCE(vs, __NULL__), am
                          ) s
                          GROUP BY grp_key
                      ) s1
                      where array_upper(levels, 1) > 1
                       UNION ALL 
                      SELECT
                          colname::text,
                          levels::text[],
                          grp_key::text
                      from (
                          SELECT
                              grp_key,
                              'cyl' as colname,
                              array_agg(levels order by dep_avg) as levels
                          from (
                              SELECT
                                  array_to_string(array[(am)::text]::text[], ',') as grp_key,
                                  (COALESCE(cyl, __NULL__))::text as levels,
                                  COALESCE(cyl, __NULL__) as dep_avg
                              FROM mt_cars
                              WHERE (am) is not NULL and (mpg) is not NULL AND COALESCE(cyl, __NULL__) IS NOT NULL
                              GROUP BY COALESCE(cyl, __NULL__), am
                          ) s
                          GROUP BY grp_key
                      ) s1
                      where array_upper(levels, 1) > 1
                      
      CONTEXT:  Traceback (most recent call last):
        PL/Python function "tree_train", line 28, in <module>
          null_handling_params, verbose_mode)
        PL/Python function "tree_train", line 489, in tree_train
        PL/Python function "tree_train", line 293, in _get_tree_states
        PL/Python function "tree_train", line 963, in _get_bins_grps
      PL/Python function "tree_train"
      

      RF:

      -- Random-forest reproduction: same grouping column ('am') and
      -- null_as_category=TRUE as the decision-tree call.
      DROP TABLE IF EXISTS mt_cars_output, mt_cars_output_group, mt_cars_output_summary;
      
      SELECT madlib.forest_train('mt_cars',         -- source table
                                 'mt_cars_output',  -- output model table
                                 'id',              -- id column
                                 'mpg',             -- dependent variable
                                 '*',               -- features
                                 'id, hp, drat, am, gear, carb',  -- exclude columns
                                 'am', -- grouping
                                 10::integer,       -- number of trees
                                 2::integer,        -- number of random features
                                 TRUE::boolean,     -- variable importance
                                 1,                 -- number of permutations
                                 10,                -- max tree depth
                                 8,                 -- min split
                                 3,                 -- min bucket
                                 10,                -- number of splits (bins per continuous variable)
                                 'null_as_category=TRUE'  -- null handling parameters
                                 );
      

      results in error:

      ERROR:  plpy.SPIError: column "__null__" does not exist
      LINE 14:                             (COALESCE(vs, __NULL__))::text a...
                                                         ^
      QUERY:  
                      SELECT
                          colname::text,
                          levels::text[],
                          grp_key::text
                      from (
                          SELECT
                              grp_key,
                              'vs' as colname,
                              array_agg(levels order by dep_avg) as levels
                          from (
                              SELECT
                                  array_to_string(array[(am)::text], ',') as grp_key,
                                  (COALESCE(vs, __NULL__))::text as levels,
                                  COALESCE(vs, __NULL__) as dep_avg
                              FROM mt_cars
                              WHERE (am) is not NULL and (mpg) is not NULL AND COALESCE(vs, __NULL__) IS NOT NULL
                              GROUP BY COALESCE(vs, __NULL__), am
                          ) s
                          GROUP BY grp_key
                      ) s1
                      where array_upper(levels, 1) > 1
                       UNION ALL 
                      SELECT
                          colname::text,
                          levels::text[],
                          grp_key::text
                      from (
                          SELECT
                              grp_key,
                              'cyl' as colname,
                              array_agg(levels order by dep_avg) as levels
                          from (
                              SELECT
                                  array_to_string(array[(am)::text], ',') as grp_key,
                                  (COALESCE(cyl, __NULL__))::text as levels,
                                  COALESCE(cyl, __NULL__) as dep_avg
                              FROM mt_cars
                              WHERE (am) is not NULL and (mpg) is not NULL AND COALESCE(cyl, __NULL__) IS NOT NULL
                              GROUP BY COALESCE(cyl, __NULL__), am
                          ) s
                          GROUP BY grp_key
                      ) s1
                      where array_upper(levels, 1) > 1
                      
      CONTEXT:  Traceback (most recent call last):
        PL/Python function "forest_train", line 42, in <module>
          sample_ratio
        PL/Python function "forest_train", line 427, in forest_train
        PL/Python function "forest_train", line 963, in _get_bins_grps
      PL/Python function "forest_train"
      

      Attachments

        Activity

          People

            riyer Rahul Iyer
            fmcquillan Frank McQuillan
            Votes:
            0 Vote for this issue
            Watchers:
            3 Start watching this issue

            Dates

              Created:
              Updated:
              Resolved: