Skip to content

Fortran support #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions codebleu/dataflow_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DFG_python,
DFG_ruby,
DFG_rust,
DFG_fortran,
index_to_code_token,
remove_comments_and_docstrings,
tree_to_token_index,
Expand All @@ -30,6 +31,7 @@
"c": DFG_csharp, # XLCoST uses C# parser for C
"cpp": DFG_csharp, # XLCoST uses C# parser for C++
"rust": DFG_rust,
"fortran": DFG_fortran,
}


Expand Down
78 changes: 78 additions & 0 deletions codebleu/keywords/fortran.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
allocatable
allocate
assignment
backspace
call
case
character
close
complex
contains
cycle
deallocate
default
do
else
else if
elsewhere
end do
end function
end if
end interface
end module
end program
end select
end subroutine
end type
end where
endfile
exit
format
function
if
implicit
in
inout
integer
intent
interface
intrinsic
inquire
kind
len
logical
module
namelist
nullify
only
open
operator
optional
out
print
pointer
private
program
public
read
real
recursive
result
return
rewind
select case
stop
subroutine
target
then
type
use
where
while
write






159 changes: 159 additions & 0 deletions codebleu/parser/DFG.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,3 +1385,162 @@ def DFG_rust(root_node, index_to_code, states):
DFG += temp

return sorted(DFG, key=lambda x: x[1]), states

def _fortran_merge_duplicate_edges(edges):
    """Collapse edges that share (token, index, relation) into a single edge.

    The parent-token lists and parent-index lists of duplicates are unioned.
    The returned list is ordered by token index.
    """
    merged = {}
    for code, idx, relation, parent_codes, parent_idxs in edges:
        key = (code, idx, relation)
        if key not in merged:
            merged[key] = [parent_codes, parent_idxs]
        else:
            # NOTE(review): the order of the unioned parent tokens comes from
            # set iteration, exactly as in the inline code this helper replaces.
            merged[key][0] = list(set(merged[key][0] + parent_codes))
            merged[key][1] = sorted(set(merged[key][1] + parent_idxs))
    return [
        (key[0], key[1], key[2], parents[0], parents[1])
        for key, parents in sorted(merged.items(), key=lambda item: item[0][1])
    ]


def _fortran_visit_children(root_node, index_to_code, states):
    """Run DFG_fortran over each child in order, threading ``states`` through."""
    edges = []
    for child in root_node.children:
        child_edges, states = DFG_fortran(child, index_to_code, states)
        edges += child_edges
    return sorted(edges, key=lambda x: x[1]), states


def DFG_fortran(root_node, index_to_code, states):
    """Extract data-flow edges from a Fortran syntax (sub)tree.

    Args:
        root_node: Syntax-tree node. The attributes read here are ``type``,
            ``children``, ``start_point``, ``end_point`` and
            ``child_by_field_name``.
        index_to_code: Mapping from ``(start_point, end_point)`` of a leaf
            node to an ``(index, token_text)`` pair.
        states: Mapping from a variable token to the list of token indices
            currently recorded for it. The argument is copied, never mutated.

    Returns:
        A ``(edges, states)`` tuple. ``edges`` is a list of
        ``(token, index, relation, parent_tokens, parent_indices)`` tuples
        sorted by ``index``, where ``relation`` is ``"comesFrom"`` or
        ``"computedFrom"``. ``states`` is the updated mapping.

    NOTE(review): the node-type names below (``assignment_stmt``,
    ``declaration_stmt``, ``do_loop``, ...) and the ``name``/``value``/
    ``left``/``right`` field names must match what the Fortran grammar
    actually emits -- TODO confirm against the grammar's node types. Nodes
    whose type matches none of them are handled by the generic traversal.
    """
    assignment = ["assignment_stmt"]
    def_statement = ["declaration_stmt"]
    # NOTE(review): an arithmetic IF is treated like the "increment" category
    # (every variable in it is computed from every variable in it) -- confirm
    # that this is intended.
    increment_statement = ["arithmetic_if_stmt"]
    if_statement = ["if_stmt", "else"]
    for_statement = ["do_loop"]
    while_statement = ["while_loop"]

    # Work on a copy so the caller's mapping is left untouched.
    states = states.copy()

    is_leaf = len(root_node.children) == 0 or root_node.type in ["string_literal", "character_literal"]
    if is_leaf and root_node.type != "comment":
        idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
        if root_node.type == code:
            # Node type equals its own text (keyword/punctuation): no edge.
            return [], states
        if code in states:
            return [(code, idx, "comesFrom", [code], states[code].copy())], states
        if root_node.type == "identifier":
            states[code] = [idx]
        return [(code, idx, "comesFrom", [], [])], states

    if root_node.type in def_statement:
        name = root_node.child_by_field_name("name")
        value = root_node.child_by_field_name("value")
        if name is None:
            # Without a "name" field there is nothing to bind; previously this
            # crashed on ``None.children``. Fall back to the generic traversal.
            return _fortran_visit_children(root_node, index_to_code, states)
        DFG = []
        if value is None:
            for index in tree_to_variable_index(name, index_to_code):
                idx, code = index_to_code[index]
                DFG.append((code, idx, "comesFrom", [], []))
                states[code] = [idx]
            return sorted(DFG, key=lambda x: x[1]), states
        name_indexs = tree_to_variable_index(name, index_to_code)
        value_indexs = tree_to_variable_index(value, index_to_code)
        temp, states = DFG_fortran(value, index_to_code, states)
        DFG += temp
        for index1 in name_indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in value_indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "comesFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in assignment:
        left_nodes = root_node.child_by_field_name("left")
        right_nodes = root_node.child_by_field_name("right")
        if left_nodes is None or right_nodes is None:
            # A missing field used to crash the recursive call on ``None``.
            return _fortran_visit_children(root_node, index_to_code, states)
        DFG = []
        # Visit the right-hand side first so its reads see the old states.
        temp, states = DFG_fortran(right_nodes, index_to_code, states)
        DFG += temp
        name_indexs = tree_to_variable_index(left_nodes, index_to_code)
        value_indexs = tree_to_variable_index(right_nodes, index_to_code)
        for index1 in name_indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in value_indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "computedFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in increment_statement:
        DFG = []
        indexs = tree_to_variable_index(root_node, index_to_code)
        for index1 in indexs:
            idx1, code1 = index_to_code[index1]
            for index2 in indexs:
                idx2, code2 = index_to_code[index2]
                DFG.append((code1, idx1, "computedFrom", [code2], [idx2]))
            states[code1] = [idx1]
        return sorted(DFG, key=lambda x: x[1]), states

    if root_node.type in if_statement:
        DFG = []
        current_states = states.copy()
        others_states = []
        flag = False  # becomes True once a nested if/else child has been seen
        tag = "else" in root_node.type  # True when an else part exists
        for child in root_node.children:
            if "else" in child.type:
                tag = True
            if child.type not in if_statement and flag is False:
                temp, current_states = DFG_fortran(child, index_to_code, current_states)
                DFG += temp
            else:
                flag = True
                # Alternative branches start from the states before the if.
                temp, new_states = DFG_fortran(child, index_to_code, states)
                DFG += temp
                others_states.append(new_states)
        others_states.append(current_states)
        if tag is False:
            # No else part: the incoming states are also a possible outcome.
            others_states.append(states)
        new_states = {}
        for dic in others_states:
            for key in dic:
                if key not in new_states:
                    new_states[key] = dic[key].copy()
                else:
                    new_states[key] += dic[key]
        for key in new_states:
            new_states[key] = sorted(set(new_states[key]))
        return sorted(DFG, key=lambda x: x[1]), new_states

    if root_node.type in for_statement:
        DFG = []
        for child in root_node.children:
            temp, states = DFG_fortran(child, index_to_code, states)
            DFG += temp
        # NOTE(review): this second pass is gated on a child of type
        # "local_variable_declaration", which looks carried over from another
        # language's grammar -- verify whether it can ever trigger for Fortran.
        flag = False
        for child in root_node.children:
            if flag:
                temp, states = DFG_fortran(child, index_to_code, states)
                DFG += temp
            elif child.type == "local_variable_declaration":
                flag = True
        return _fortran_merge_duplicate_edges(DFG), states

    if root_node.type in while_statement:
        DFG = []
        # Two passes over the children, so the second pass sees the states
        # left behind by the first.
        for _ in range(2):
            for child in root_node.children:
                temp, states = DFG_fortran(child, index_to_code, states)
                DFG += temp
        return _fortran_merge_duplicate_edges(DFG), states

    # Any other node: visit the children in order.
    return _fortran_visit_children(root_node, index_to_code, states)
1 change: 1 addition & 0 deletions codebleu/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DFG_python,
DFG_ruby,
DFG_rust,
DFG_fortran,
)
from .utils import (
index_to_code_token,
Expand Down
5 changes: 5 additions & 0 deletions codebleu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"go",
"ruby",
"rust",
"fortran"
] # keywords available


Expand Down Expand Up @@ -173,6 +174,10 @@ def get_tree_sitter_language(lang: str) -> Language:
import tree_sitter_rust

return Language(tree_sitter_rust.language())
elif lang == "fortran":
import tree_sitter_fortran

return Language(tree_sitter_fortran.language())
else:
assert False, "Not reachable"
except ImportError:
Expand Down
1 change: 1 addition & 0 deletions tests/test_codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None:
("go", ["func foo ( x ) { return x }"], ["func bar ( y ) {\n return y\n}"]),
("ruby", ["def foo ( x ) :\n return x"], ["def bar ( y ) :\n return y"]),
("rust", ["fn foo ( x ) -> i32 { x }"], ["fn bar ( y ) -> i32 { y }"]),
("fortran", ["function foo ( x ) result ( x )\n end function foo"], ["function bar ( y ) result ( y )\n end function bar"]),
],
)
def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None:
Expand Down
31 changes: 31 additions & 0 deletions use.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Manual smoke check: score a small Fortran function against itself with CodeBLEU."""

from codebleu import calc_codebleu

# The same snippet is used as both prediction and reference.
FORTRAN_SNIPPET = """function foo(x)
real :: foo
foo = x
end function foo
"""


def main() -> None:
    """Compute and print the CodeBLEU result for identical Fortran inputs."""
    prediction = FORTRAN_SNIPPET
    reference = FORTRAN_SNIPPET
    result = calc_codebleu(
        [reference],
        [prediction],
        lang="fortran",
        weights=(0.25, 0.25, 0.25, 0.25),
        tokenizer=None,
    )
    print(result)


if __name__ == "__main__":
    main()