Resp/Req rewriting works
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #include <boost/regex.hpp>
29 #include <boost/lexical_cast.hpp>
30
31 #define BOOST_AUTO_TEST_MAIN
32 #define BOOST_TEST_DYN_LINK
33
34 #include <boost/test/auto_unit_test.hpp>
35
36 using namespace boost::unit_test;
37 namespace mp = metaproxy_1;
38
39 class FilterHeaderRewrite: public mp::filter::Base {
40 public:
41     void process(mp::Package & package) const {
42         Z_GDU *gdu = package.request().get();
43         //map of request/response vars
44         std::map<std::string, std::string> vars;
45         //we have an http req
46         if (gdu && gdu->which == Z_GDU_HTTP_Request)
47         {
48             std::cout << ">> Request headers" << std::endl;
49             Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
50             mp::odr o;
51             //iterate headers
52             for (Z_HTTP_Header *header = hreq->headers;
53                     header != 0; 
54                     header = header->next) 
55             {
56                 std::cout << header->name << ": " << header->value << std::endl;
57                 rewrite_header(o, vars, header, req_uri_rx, req_uri_pat);
58             }
59             package.request() = gdu;
60         }
61         package.move();
62         gdu = package.response().get();
63         if (gdu && gdu->which == Z_GDU_HTTP_Response)
64         {
65             std::cout << "<< Respose headers" << std::endl;
66             Z_HTTP_Response *hr = gdu->u.HTTP_Response;
67             mp::odr o;
68             //iterate headers
69             for (Z_HTTP_Header *header = hr->headers;
70                     header != 0; 
71                     header = header->next) 
72             {
73                 std::cout << header->name << ": " << header->value << std::endl;
74                 rewrite_header(o, vars, header, resp_uri_rx, resp_uri_pat);
75             }
76             package.response() = gdu;
77         }
78     };
79
80     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
81
82     void rewrite_header(mp::odr & o, 
83             std::map<std::string, std::string> & vars,
84             Z_HTTP_Header *header,
85             const std::string & uri_re,
86             const std::string & uri_pat) const
87     {
88         //exec regex against value
89         boost::regex re(uri_re);
90         boost::smatch what;
91         std::string hvalue(header->value);
92         std::string::const_iterator start, end;
93         start = hvalue.begin();
94         end = hvalue.end();
95         while (regex_search(start, end, what, re)) //find next full match
96         {
97             unsigned i;
98             for (i = 1; i < what.size(); ++i)
99             {
100                 //check if the group is named
101                 std::map<int, std::string>::const_iterator it
102                     = groups_by_num.find(i);
103                 if (it != groups_by_num.end()) 
104                 {   //it is
105                     std::string name = it->second;
106                     vars[name] = what[i];
107                 }
108
109             }
110             //prepare replacement string
111             std::string rvalue = sub_vars(uri_pat, vars);
112             //rewrite value
113             std::string rhvalue = what.prefix().str() 
114                 + rvalue + what.suffix().str();
115             header->value = odr_strdup(o, rhvalue.c_str());
116             std::cout << "! Rewritten '"+hvalue+"' to '"+rhvalue+"'\n";
117             start = what[0].second; //move search forward
118         }
119     };
120
121     static void parse_groups(const std::string & str,
122             std::map<int, std::string> & groups_bynum,
123             std::map<std::string, int> & groups_byname)
124     {
125        int gnum = 0;
126        bool esc = false;
127        for (int i = 0; i < str.size(); ++i)
128        {
129            if (!esc && str[i] == '\\')
130            {
131                esc = true;
132                continue;
133            }
134            if (!esc && str[i] == '(') //group starts
135            {
136                gnum++;
137                if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
138                {
139                    i++;
140                    if (i+1 < str.size() && str[i+1] == 'P') //optional, python
141                        i++;
142                    if (i+1 < str.size() && str[i+1] == '<') //named
143                    {
144                        i++;
145                        std::string gname;
146                        bool term = false;
147                        while (++i < str.size())
148                        {
149                            if (str[i] == '>') { term = true; break; }
150                            if (!isalnum(str[i])) 
151                                throw mp::filter::FilterException
152                                    ("Only alphanumeric chars allowed, found "
153                                     " in '" 
154                                     + str 
155                                     + "' at " 
156                                     + boost::lexical_cast<std::string>(i)); 
157                            gname += str[i];
158                        }
159                        if (!term)
160                            throw mp::filter::FilterException
161                                ("Unterminated group name '" + gname 
162                                 + " in '" + str +"'");
163                       groups_bynum[gnum] = gname;
164                       groups_byname[gname] = gnum;
165                       std::cout << "Found named group '" << gname 
166                           << "' at $" << gnum << std::endl;
167                    }
168                }
169            }
170            esc = false;
171        }
172     }
173
174     static std::string sub_vars (const std::string & in, 
175             const std::map<std::string, std::string> & vars)
176     {
177         std::string out;
178         bool esc = false;
179         for (int i = 0; i < in.size(); ++i)
180         {
181             if (!esc && in[i] == '\\')
182             {
183                 esc = true;
184                 continue;
185             }
186             if (!esc && in[i] == '$') //var
187             {
188                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
189                 {
190                     ++i;
191                     std::string name;
192                     bool term = false;
193                     while (++i < in.size()) 
194                     {
195                         if (in[i] == '}') { term = true; break; }
196                         name += in[i];
197                     }
198                     if (!term) throw mp::filter::FilterException
199                         ("Unterminated var ref in '"+in+"' at "
200                          + boost::lexical_cast<std::string>(i));
201                     std::map<std::string, std::string>::const_iterator it
202                         = vars.find(name);
203                     if (it != vars.end())
204                         out += it->second;
205                 }
206                 else
207                 {
208                     throw mp::filter::FilterException
209                         ("Malformed or trimmed var ref in '"
210                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
211                 }
212                 continue;
213             }
214             //passthru
215             out += in[i];
216             esc = false;
217         }
218         return out;
219     }
220     
221     void configure(
222             const std::string & req_uri_rx, 
223             const std::string & req_uri_pat, 
224             const std::string & resp_uri_rx, 
225             const std::string & resp_uri_pat) 
226     {
227        this->req_uri_rx = req_uri_rx;
228        this->req_uri_pat = req_uri_pat;
229        //pick up names
230        parse_groups(req_uri_rx, groups_by_num, groups_by_name);
231        this->resp_uri_rx = resp_uri_rx;
232        this->resp_uri_pat = resp_uri_pat;
233     };
234
235 private:
236     std::map<std::string, std::string> vars;
237     std::string req_uri_rx;
238     std::string resp_uri_rx;
239     std::string req_uri_pat;
240     std::string resp_uri_pat;
241     std::map<int, std::string> groups_by_num;
242     std::map<std::string, int> groups_by_name;
243
244 };
245
246
247 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
248 {
249     try
250     {
251        FilterHeaderRewrite fhr;
252     }
253     catch ( ... ) {
254         BOOST_CHECK (false);
255     }
256 }
257
258 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
259 {
260     try
261     {
262         mp::RouterChain router;
263
264         FilterHeaderRewrite fhr;
265         fhr.configure(
266                 "(?<host>[A-Za-z.]+):(?<port>\\d+)",
267                 "http://${host}:${port}/somepath",
268                 //rewrite connection close
269                 "close",
270                 "open");
271
272         mp::filter::HTTPClient hc;
273         
274         router.append(fhr);
275         router.append(hc);
276
277         // create an http request
278         mp::Package pack;
279
280         mp::odr odr;
281         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
282             "http://localhost:80/~jakub/targetsite.php", 0, 1);
283
284         pack.request() = gdu_req;
285
286         //feed to the router
287         pack.router(router).move();
288
289         //analyze the response
290         Z_GDU *gdu_res = pack.response().get();
291         BOOST_CHECK(gdu_res);
292         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
293         
294         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
295         BOOST_CHECK(hres);
296
297     }
298     catch (std::exception & e) {
299         std::cout << e.what();
300         BOOST_CHECK (false);
301     }
302 }
303
304 /*
305  * Local variables:
306  * c-basic-offset: 4
307  * c-file-style: "Stroustrup"
308  * indent-tabs-mode: nil
309  * End:
310  * vim: shiftwidth=4 tabstop=8 expandtab
311  */
312