Skip to content

Commit f35286c

Browse files
committed
Ensure detection classes are correctly taken into account
1 parent f17b845 commit f35286c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

biapy/biapy_aux_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,8 @@ def check_csv_files(data_dir, is_3d=False, dir_name=None):
638638
# Discard first index column to not have error if it is not sorted
639639
p_number = df.iloc[:, 0].to_list()
640640
df = df.rename(columns=lambda x: x.strip()) # trim spaces in column names
641-
cols_not_in_file = [x for x in req_columns if x not in df.columns]
641+
columns_present = [x.lower() for x in df.columns]
642+
cols_not_in_file = [x for x in req_columns if x not in columns_present]
642643
if len(cols_not_in_file) > 0:
643644
if len(cols_not_in_file) == 1:
644645
error_message = f"'{cols_not_in_file[0]}' column is not present in CSV file:\n{csv_path}"
@@ -647,7 +648,7 @@ def check_csv_files(data_dir, is_3d=False, dir_name=None):
647648
return True, error_message, {}
648649

649650
# Check class in columns
650-
if 'class' in df.columns:
651+
if 'class' in columns_present:
651652
df["class"] = df["class"].astype("int")
652653
class_point = np.array(df["class"])
653654

@@ -661,10 +662,9 @@ def check_csv_files(data_dir, is_3d=False, dir_name=None):
661662

662663
nclasses = uniq.max()
663664

664-
665665
constraints = {}
666-
if 'class' in df.columns:
667-
constraints["MODEL.N_CLASSES"] = nclasses
666+
if 'class' in columns_present:
667+
constraints["MODEL.N_CLASSES"] = int(nclasses)
668668
if dir_name is not None:
669669
constraints[dir_name] = len(ids)
670670
constraints[dir_name+"_path"] = data_dir

0 commit comments

Comments
 (0)