@@ -638,7 +638,8 @@ def check_csv_files(data_dir, is_3d=False, dir_name=None):
638
638
# Discard first index column to not have error if it is not sorted
639
639
p_number = df .iloc [:, 0 ].to_list ()
640
640
df = df .rename (columns = lambda x : x .strip ()) # trim spaces in column names
641
- cols_not_in_file = [x for x in req_columns if x not in df .columns ]
641
+ columns_present = [x .lower () for x in df .columns ]
642
+ cols_not_in_file = [x for x in req_columns if x not in columns_present ]
642
643
if len (cols_not_in_file ) > 0 :
643
644
if len (cols_not_in_file ) == 1 :
644
645
error_message = f"'{ cols_not_in_file [0 ]} ' column is not present in CSV file:\n { csv_path } "
@@ -647,7 +648,7 @@ def check_csv_files(data_dir, is_3d=False, dir_name=None):
647
648
return True , error_message , {}
648
649
649
650
# Check class in columns
650
- if 'class' in df . columns :
651
+ if 'class' in columns_present :
651
652
df ["class" ] = df ["class" ].astype ("int" )
652
653
class_point = np .array (df ["class" ])
653
654
@@ -661,10 +662,9 @@ def check_csv_files(data_dir, is_3d=False, dir_name=None):
661
662
662
663
nclasses = uniq .max ()
663
664
664
-
665
665
constraints = {}
666
- if 'class' in df . columns :
667
- constraints ["MODEL.N_CLASSES" ] = nclasses
666
+ if 'class' in columns_present :
667
+ constraints ["MODEL.N_CLASSES" ] = int ( nclasses )
668
668
if dir_name is not None :
669
669
constraints [dir_name ] = len (ids )
670
670
constraints [dir_name + "_path" ] = data_dir
0 commit comments