1 : // Copyright 2014 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/kasko/service_bridge.h"
16 :
17 : #include <Windows.h> // NOLINT
18 : #include <Rpc.h>
19 :
20 : #include <vector>
21 :
22 : #include "base/bind.h"
23 : #include "base/callback.h"
24 : #include "base/callback_helpers.h"
25 : #include "base/location.h"
26 : #include "base/macros.h"
27 : #include "base/message_loop/message_loop.h"
28 : #include "base/process/process_handle.h"
29 : #include "base/strings/string16.h"
30 : #include "base/strings/string_number_conversions.h"
31 : #include "base/synchronization/waitable_event.h"
32 : #include "base/threading/thread.h"
33 : #include "gtest/gtest.h"
34 : #include "syzygy/common/rpc/helpers.h"
35 : #include "syzygy/kasko/kasko_rpc.h"
36 : #include "syzygy/kasko/service.h"
37 : #include "syzygy/kasko/testing/mock_service.h"
38 :
39 : namespace kasko {
40 :
41 : namespace {
42 :
43 : const base::char16* const kValidRpcProtocol = L"ncalrpc";
44 : const base::char16* const kTestRpcEndpointPrefix = L"syzygy-kasko-test-svc";
45 :
46 E : base::string16 GetTestEndpoint() {
47 E : return kTestRpcEndpointPrefix + base::UintToString16(::GetCurrentProcessId());
48 E : }
49 :
50 : class BlockingService : public Service {
51 : public:
52 : BlockingService(base::WaitableEvent* release_call,
53 : base::WaitableEvent* blocking);
54 : virtual ~BlockingService();
55 :
56 : // Service implementation
57 : virtual void SendDiagnosticReport(
58 : base::ProcessId client_process_id,
59 : uint64_t exception_info_address,
60 : base::PlatformThreadId thread_id,
61 : MinidumpType minidump_type,
62 : const char* protobuf,
63 : size_t protobuf_length,
64 : const std::map<base::string16, base::string16>& crash_keys) override;
65 :
66 : private:
67 : base::WaitableEvent* release_call_;
68 : base::WaitableEvent* blocking_;
69 : DISALLOW_COPY_AND_ASSIGN(BlockingService);
70 : };
71 :
72 : BlockingService::BlockingService(base::WaitableEvent* release_call,
73 : base::WaitableEvent* blocking)
74 E : : release_call_(release_call), blocking_(blocking) {}
75 :
76 E : BlockingService::~BlockingService() {}
77 :
78 : void BlockingService::SendDiagnosticReport(
79 : base::ProcessId client_process_id,
80 : uint64_t exception_info_address,
81 : base::PlatformThreadId thread_id,
82 : MinidumpType minidump_type,
83 : const char* protobuf,
84 : size_t protobuf_length,
85 E : const std::map<base::string16, base::string16>& crash_keys) {
86 E : blocking_->Signal();
87 E : release_call_->Wait();
88 E : }
89 :
90 : void InvokeAndCheckRpcStatus(const base::Callback<RPC_STATUS(void)>& callback) {
91 : ASSERT_EQ(RPC_S_OK, callback.Run());
92 : }
93 :
94 : base::Closure WrapRpcStatusCallback(
95 : const base::Callback<RPC_STATUS(void)>& callback) {
96 : return base::Bind(InvokeAndCheckRpcStatus, callback);
97 : }
98 :
99 : void DoInvokeService(const base::string16& protocol,
100 : const base::string16& endpoint,
101 : const std::string& protobuf,
102 : bool* complete,
103 : size_t crash_keys_length,
104 E : const CrashKey* crash_keys) {
105 E : common::rpc::ScopedRpcBinding rpc_binding;
106 E : ASSERT_TRUE(rpc_binding.Open(protocol, endpoint));
107 :
108 : common::rpc::RpcStatus status = common::rpc::InvokeRpc(
109 : KaskoClient_SendDiagnosticReport, rpc_binding.Get(), NULL, 0, SMALL_DUMP,
110 : protobuf.length(), reinterpret_cast<const signed char*>(protobuf.c_str()),
111 E : crash_keys_length, crash_keys);
112 E : ASSERT_FALSE(status.exception_occurred);
113 E : ASSERT_TRUE(status.succeeded());
114 E : *complete = true;
115 E : }
116 :
117 : } // namespace
118 :
119 E : TEST(KaskoServiceBridgeTest, ConstructDestruct) {
120 E : std::vector<testing::MockService::CallRecord> call_log;
121 : {
122 : ServiceBridge instance(
123 : L"aaa", L"bbb",
124 E : scoped_ptr<Service>(new testing::MockService(&call_log)));
125 E : }
126 : {
127 : ServiceBridge instance(
128 : L"aaa", L"bbb",
129 E : scoped_ptr<Service>(new testing::MockService(&call_log)));
130 E : }
131 E : }
132 :
133 E : TEST(KaskoServiceBridgeTest, StopNonRunningInstance) {
134 E : std::vector<testing::MockService::CallRecord> call_log;
135 : ServiceBridge instance(
136 E : L"aaa", L"bbb", scoped_ptr<Service>(new testing::MockService(&call_log)));
137 E : instance.Stop();
138 E : }
139 :
140 E : TEST(KaskoServiceBridgeTest, FailToRunWithBadProtocol) {
141 E : std::vector<testing::MockService::CallRecord> call_log;
142 : {
143 : ServiceBridge instance(
144 : L"aaa", GetTestEndpoint(),
145 E : scoped_ptr<Service>(new testing::MockService(&call_log)));
146 E : ASSERT_FALSE(instance.Run());
147 : // Stop should not crash in this case.
148 E : instance.Stop();
149 E : }
150 E : }
151 :
152 E : TEST(KaskoServiceBridgeTest, RunSuccessfully) {
153 E : std::vector<testing::MockService::CallRecord> call_log;
154 :
155 : {
156 : ServiceBridge instance(
157 : kValidRpcProtocol, GetTestEndpoint(),
158 E : scoped_ptr<Service>(new testing::MockService(&call_log)));
159 E : ASSERT_TRUE(instance.Run());
160 E : instance.Stop();
161 :
162 : // Second run, same instance.
163 E : ASSERT_TRUE(instance.Run());
164 E : instance.Stop();
165 E : }
166 : {
167 : // Second instance.
168 : ServiceBridge instance(
169 : kValidRpcProtocol, GetTestEndpoint(),
170 E : scoped_ptr<Service>(new testing::MockService(&call_log)));
171 E : ASSERT_TRUE(instance.Run());
172 E : instance.Stop();
173 E : }
174 E : }
175 :
176 E : TEST(KaskoServiceBridgeTest, InvokeService) {
177 E : std::vector<testing::MockService::CallRecord> call_log;
178 :
179 E : base::string16 protocol = kValidRpcProtocol;
180 E : base::string16 endpoint = GetTestEndpoint();
181 : ServiceBridge instance(
182 : protocol, endpoint,
183 E : scoped_ptr<Service>(new testing::MockService(&call_log)));
184 E : ASSERT_TRUE(instance.Run());
185 :
186 : base::ScopedClosureRunner stop_service_bridge(
187 E : base::Bind(&ServiceBridge::Stop, base::Unretained(&instance)));
188 :
189 E : std::string protobuf = "hello world";
190 E : bool complete = false;
191 : CrashKey crash_keys[] = {{reinterpret_cast<const signed char*>("foo"),
192 : reinterpret_cast<const signed char*>("bar")},
193 : {reinterpret_cast<const signed char*>("hello"),
194 E : reinterpret_cast<const signed char*>("world")}};
195 :
196 : DoInvokeService(protocol, endpoint, protobuf, &complete,
197 E : arraysize(crash_keys), crash_keys);
198 E : ASSERT_TRUE(complete);
199 E : ASSERT_EQ(1u, call_log.size());
200 E : ASSERT_EQ(::GetCurrentProcessId(), call_log[0].client_process_id);
201 E : ASSERT_EQ(protobuf, call_log[0].protobuf);
202 E : ASSERT_EQ(2u, call_log[0].crash_keys.size());
203 E : auto entry = call_log[0].crash_keys.find(L"foo");
204 E : ASSERT_NE(call_log[0].crash_keys.end(), entry);
205 E : ASSERT_EQ(L"bar", entry->second);
206 E : entry = call_log[0].crash_keys.find(L"hello");
207 E : ASSERT_NE(call_log[0].crash_keys.end(), entry);
208 E : ASSERT_EQ(L"world", entry->second);
209 E : }
210 :
211 :
212 E : TEST(KaskoServiceBridgeTest, StopBlocksUntilCallsComplete) {
213 E : base::WaitableEvent release_call(false, false);
214 E : base::WaitableEvent blocking(false, false);
215 :
216 E : base::string16 protocol = kValidRpcProtocol;
217 E : base::string16 endpoint = GetTestEndpoint();
218 : ServiceBridge instance(
219 : protocol, endpoint,
220 E : scoped_ptr<Service>(new BlockingService(&release_call, &blocking)));
221 E : ASSERT_TRUE(instance.Run());
222 :
223 : base::ScopedClosureRunner stop_service_bridge(
224 E : base::Bind(&ServiceBridge::Stop, base::Unretained(&instance)));
225 : // In case an assertion fails, make sure that we will not block.
226 : base::ScopedClosureRunner signal_release_call(base::Bind(
227 E : &base::WaitableEvent::Signal, base::Unretained(&release_call)));
228 :
229 E : std::string protobuf = "hello world";
230 E : bool complete = false;
231 : CrashKey crash_keys[] = {{reinterpret_cast<const signed char*>("foo"),
232 : reinterpret_cast<const signed char*>("bar")},
233 : {reinterpret_cast<const signed char*>("hello"),
234 E : reinterpret_cast<const signed char*>("world")}};
235 :
236 E : base::Thread client_thread("client thread");
237 E : ASSERT_TRUE(client_thread.Start());
238 : client_thread.message_loop()->PostTask(
239 : FROM_HERE,
240 : base::Bind(&DoInvokeService, protocol, endpoint, protobuf,
241 : base::Unretained(&complete), arraysize(crash_keys),
242 E : base::Unretained(crash_keys)));
243 : // In case the DoInvokeService fails, let's make sure we unblock ourselves.
244 : client_thread.message_loop()->PostTask(
245 : FROM_HERE,
246 E : base::Bind(&base::WaitableEvent::Signal, base::Unretained(&blocking)));
247 E : blocking.Wait();
248 :
249 : // Either DoInvokeService failed (complete == true), or we are blocking in
250 : // BlockingService::SendDiagnosticReport (complete == false).
251 E : ASSERT_FALSE(complete);
252 :
253 : // Reduce the chance of false positives by giving the service call a chance to
254 : // complete. (It shouldn't.)
255 E : ::Sleep(100);
256 :
257 E : base::Thread stop_thread("stop thread");
258 E : ASSERT_TRUE(stop_thread.Start());
259 : stop_thread.message_loop()->PostTask(
260 E : FROM_HERE, base::Bind(&ServiceBridge::Stop, base::Unretained(&instance)));
261 E : ASSERT_FALSE(complete);
262 :
263 : // Stop is waiting for the pending call to complete. Let's unblock it now.
264 E : release_call.Signal();
265 :
266 : // This will not return until the ServiceBridge::Stop has completed.
267 E : stop_thread.Stop();
268 E : ASSERT_TRUE(complete);
269 E : }
270 :
271 : } // namespace kasko
|