14
14
15
15
#include < lpython/python_ast.h>
16
16
#include < libasr/string_utils.h>
17
+ #include < lpython/parser/parser_exception.h>
17
18
18
19
// This is only used in parser.tab.cc, nowhere else, so we simply include
19
20
// everything from LFortran::AST to save typing:
@@ -664,9 +665,78 @@ static inline ast_t* BOOLOP_01(Allocator &al, Location &loc,
664
665
#define COMPARE (x, op, y, l ) make_Compare_t(p.m_a, l, \
665
666
EXPR (x), cmpopType::op, EXPRS(A2LIST(p.m_a, y)), 1)
666
667
667
- char* concat_string(Allocator &al, ast_t *a, char *b) {
668
- char *s = down_cast2<ConstantStr_t>(a)->m_value ;
669
- return LFortran::s2c (al, std::string (s) + std::string (b));
668
+ static inline ast_t* concat_string(Allocator &al, Location &l,
669
+ expr_t *string, std::string str, expr_t *string_literal) {
670
+ std::string str1 = " " ;
671
+ ast_t * tmp = nullptr ;
672
+ Vec<expr_t *> exprs;
673
+ exprs.reserve (al, 4 );
674
+
675
+ // TODO: Merge two concurrent ConstantStr's into one in the JoinedStr
676
+ if (string_literal) {
677
+ if (is_a<ConstantStr_t>(*string)
678
+ && is_a<ConstantStr_t>(*string_literal)) {
679
+ str1 = std::string (down_cast<ConstantStr_t>(string)->m_value );
680
+ str = std::string (down_cast<ConstantStr_t>(string_literal)->m_value );
681
+ str1 = str1 + str;
682
+ tmp = make_ConstantStr_t (al, l, LFortran::s2c (al, str1), nullptr );
683
+ } else if (is_a<JoinedStr_t>(*string)
684
+ && is_a<JoinedStr_t>(*string_literal)) {
685
+ JoinedStr_t *t = down_cast<JoinedStr_t>(string);
686
+ for (size_t i = 0 ; i < t->n_values ; i++) {
687
+ exprs.push_back (al, t->m_values [i]);
688
+ }
689
+ t = down_cast<JoinedStr_t>(string_literal);
690
+ for (size_t i = 0 ; i < t->n_values ; i++) {
691
+ exprs.push_back (al, t->m_values [i]);
692
+ }
693
+ tmp = make_JoinedStr_t (al, l, exprs.p , exprs.size ());
694
+ } else if (is_a<JoinedStr_t>(*string)
695
+ && is_a<ConstantStr_t>(*string_literal)) {
696
+ JoinedStr_t *t = down_cast<JoinedStr_t>(string);
697
+ for (size_t i = 0 ; i < t->n_values ; i++) {
698
+ exprs.push_back (al, t->m_values [i]);
699
+ }
700
+ exprs.push_back (al, string_literal);
701
+ tmp = make_JoinedStr_t (al, l, exprs.p , exprs.size ());
702
+ } else if (is_a<ConstantStr_t>(*string)
703
+ && is_a<JoinedStr_t>(*string_literal)) {
704
+ exprs.push_back (al, string);
705
+ JoinedStr_t *t = down_cast<JoinedStr_t>(string_literal);
706
+ for (size_t i = 0 ; i < t->n_values ; i++) {
707
+ exprs.push_back (al, t->m_values [i]);
708
+ }
709
+ tmp = make_JoinedStr_t (al, l, exprs.p , exprs.size ());
710
+ } else if (is_a<ConstantBytes_t>(*string)
711
+ && is_a<ConstantBytes_t>(*string_literal)) {
712
+ str1 = std::string (down_cast<ConstantBytes_t>(string)->m_value );
713
+ str1 = str1.substr (0 , str1.size () - 1 );
714
+ str = std::string (down_cast<ConstantBytes_t>(string_literal)->m_value );
715
+ str = str.substr (2 , str.size ());
716
+ str1 = str1 + str;
717
+ tmp = make_ConstantBytes_t (al, l, LFortran::s2c (al, str1), nullptr );
718
+ } else {
719
+ throw LFortran::parser_local::ParserError (
720
+ " The byte and non-byte literals can not be combined" , l);
721
+ }
722
+ } else {
723
+ if (is_a<ConstantStr_t>(*string)) {
724
+ str1 = std::string (down_cast<ConstantStr_t>(string)->m_value );
725
+ str1 = str1 + str;
726
+ tmp = make_ConstantStr_t (al, l, LFortran::s2c (al, str1), nullptr );
727
+ } else if (is_a<JoinedStr_t>(*string)) {
728
+ JoinedStr_t *t = down_cast<JoinedStr_t>(string);
729
+ for (size_t i = 0 ; i < t->n_values ; i++) {
730
+ exprs.push_back (al, t->m_values [i]);
731
+ }
732
+ exprs.push_back (al, (expr_t *)make_ConstantStr_t (al, l,
733
+ LFortran::s2c (al, str), nullptr ));
734
+ tmp = make_JoinedStr_t (al, l, exprs.p , exprs.size ());
735
+ } else {
736
+ LFORTRAN_ASSERT (false );
737
+ }
738
+ }
739
+ return tmp;
670
740
}
671
741
672
742
char * unescape (Allocator &al, LFortran::Str &s) {
@@ -687,8 +757,9 @@ char* unescape(Allocator &al, LFortran::Str &s) {
687
757
// `x.int_n` is of type BigInt but we store the int64_t directly in AST
688
758
#define INTEGER (x, l ) make_ConstantInt_t(p.m_a, l, x, nullptr )
689
759
#define STRING1 (x, l ) make_ConstantStr_t(p.m_a, l, unescape(p.m_a, x), nullptr )
690
- #define STRING2 (x, y, l ) make_ConstantStr_t (p.m_a, l, concat_string(p.m_a, x , y.c_str(p.m_a) ), nullptr )
760
+ #define STRING2 (x, y, l ) concat_string (p.m_a, l, EXPR(x) , y.str( ), nullptr )
691
761
#define STRING3 (id, x, l ) PREFIX_STRING(p.m_a, l, name2char(id), x.c_str(p.m_a))
762
+ #define STRING4 (x, s, l ) concat_string(p.m_a, l, EXPR(x), " " , EXPR(s))
692
763
#define FLOAT (x, l ) make_ConstantFloat_t(p.m_a, l, x, nullptr )
693
764
#define COMPLEX (x, l ) make_ConstantComplex_t(p.m_a, l, 0 , x, nullptr )
694
765
#define BOOL (x, l ) make_ConstantBool_t(p.m_a, l, x, nullptr )
0 commit comments