Skip to content
161 changes: 153 additions & 8 deletions automata/fa/nfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import re
from collections import deque
from itertools import chain, count, product, repeat
from typing import (
Expand Down Expand Up @@ -216,21 +217,165 @@ def from_regex(
Self
The NFA accepting the language of the input regex.
"""
# Import the shorthand character classes
from automata.regex.parser import (
DIGIT_CHARS,
NON_DIGIT_CHARS,
NON_WHITESPACE_CHARS,
NON_WORD_CHARS,
WHITESPACE_CHARS,
WORD_CHARS,
)
Comment on lines +221 to +228
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos Can you please keep all imports at the top of the file? There's no particular need for the tighter scoping here, IMO.

cc @eliotwrobson

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think ruff will complain about the imports.


# Create a set for additional symbols from shorthand classes
additional_symbols = set()

# Check for shorthand classes in the regex
if "\\s" in regex:
additional_symbols.update(WHITESPACE_CHARS)
if "\\S" in regex:
additional_symbols.update(NON_WHITESPACE_CHARS)
if "\\d" in regex:
additional_symbols.update(DIGIT_CHARS)
if "\\D" in regex:
additional_symbols.update(NON_DIGIT_CHARS)
if "\\w" in regex:
additional_symbols.update(WORD_CHARS)
if "\\W" in regex:
additional_symbols.update(NON_WORD_CHARS)
Comment on lines +234 to +245
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos Can you please refactor this to use a dict-based lookup table? That would make this much less repetitive.

cc @eliotwrobson

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this could make things much cleaner, especially since this can be done in a loop 👍🏽


# Extract escaped sequences from the regex
escape_chars = set()
i = 0
while i < len(regex):
if regex[i] == "\\" and i + 1 < len(regex):
from automata.regex.parser import _handle_escape_sequences

# Skip shorthand classes
if regex[i + 1] in "sSwWdD":
i += 2
continue

escaped_char = _handle_escape_sequences(regex[i + 1])
escape_chars.add(escaped_char)
i += 2
else:
i += 1

class_symbols = set()
range_pattern = re.compile(r"\[([^\]]*)\]")
for match in range_pattern.finditer(regex):
class_content = match.group(1)
pos = 0
while pos < len(class_content):
if class_content[pos] == "\\" and pos + 1 < len(class_content):
# Check for shorthand classes in character classes
if class_content[pos + 1] == "s":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might be able to use a lookup table from a dictionary here? Just use the character as a key and a tuple of additional symbols and position increment as the value.

additional_symbols.update(WHITESPACE_CHARS)
pos += 2
continue
elif class_content[pos + 1] == "d":
additional_symbols.update(DIGIT_CHARS)
pos += 2
continue
elif class_content[pos + 1] == "w":
additional_symbols.update(WORD_CHARS)
pos += 2
continue
elif class_content[pos + 1] in "S":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos What is the intention of using in here as opposed to ==? If the right-hand side is just a single character, the only difference that seems to make is permitting class_content[pos + 1] to be empty string (in addition to the character itself). In other words:

"S" in "S" True
"" in "S"  # True

additional_symbols.update(NON_WHITESPACE_CHARS)
pos += 2
continue
elif class_content[pos + 1] in "D":
additional_symbols.update(NON_DIGIT_CHARS)
pos += 2
continue
elif class_content[pos + 1] in "W":
additional_symbols.update(NON_WORD_CHARS)
pos += 2
continue

# Handle escape sequence in character class
from automata.regex.parser import _handle_escape_sequences
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos Can you also please move this import to the top of the file?


escaped_char = _handle_escape_sequences(class_content[pos + 1])
class_symbols.add(escaped_char)

# Check if this is part of a range
if (
pos + 2 < len(class_content)
and class_content[pos + 2] == "-"
and pos + 3 < len(class_content)
):
# Handle range with escaped start character
start_char = escaped_char

# Check if end character is also escaped
if class_content[pos + 3] == "\\" and pos + 4 < len(
class_content
):
end_char = _handle_escape_sequences(class_content[pos + 4])
pos += 5
else:
end_char = class_content[pos + 3]
pos += 4

# Add all characters in the range to input symbols
for i in range(ord(start_char), ord(end_char) + 1):
class_symbols.add(chr(i))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be done with a python update call instead of a loop.

continue

pos += 2
elif pos + 2 < len(class_content) and class_content[pos + 1] == "-":
# Handle normal range
start_char = class_content[pos]

if class_content[pos + 2] == "\\" and pos + 3 < len(class_content):
end_char = _handle_escape_sequences(class_content[pos + 3])
pos += 4
else:
end_char = class_content[pos + 2]
pos += 3

for i in range(ord(start_char), ord(end_char) + 1):
class_symbols.add(chr(i))
else:
if class_content[pos] != "^": # Skip negation symbol
class_symbols.add(class_content[pos])
pos += 1

# Set up the final input symbols
if input_symbols is None:
input_symbols = frozenset(regex) - RESERVED_CHARACTERS
# If no input_symbols provided, collect all non-reserved chars from regex
input_symbols_set = set()
for char in regex:
if char not in RESERVED_CHARACTERS:
input_symbols_set.add(char)

# Include all character class symbols and escape sequences
input_symbols_set.update(class_symbols)
input_symbols_set.update(escape_chars)

# Add the shorthand characters
input_symbols_set.update(additional_symbols)

final_input_symbols = frozenset(input_symbols_set)
else:
conflicting_symbols = RESERVED_CHARACTERS & input_symbols
if conflicting_symbols:
raise exceptions.InvalidSymbolError(
f"Invalid input symbols: {conflicting_symbols}"
)
# For user-provided input_symbols, we need to update
# with character class symbols and escape sequences
final_input_symbols = (
frozenset(input_symbols)
.union(class_symbols)
.union(escape_chars)
.union(additional_symbols)
)

nfa_builder = parse_regex(regex, input_symbols)
# Build the NFA
nfa_builder = parse_regex(regex, final_input_symbols)

return cls(
states=frozenset(nfa_builder._transitions.keys()),
input_symbols=input_symbols,
input_symbols=final_input_symbols,
transitions=nfa_builder._transitions,
initial_state=nfa_builder._initial_state,
final_states=nfa_builder._final_states,
Expand Down
Loading
Loading