1 : // Copyright 2014 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/kasko/service_bridge.h"
16 :
17 : #include <Windows.h> // NOLINT
18 : #include <Rpc.h>
19 :
20 : #include <vector>
21 :
22 : #include "base/bind.h"
23 : #include "base/callback.h"
24 : #include "base/callback_helpers.h"
25 : #include "base/location.h"
26 : #include "base/macros.h"
27 : #include "base/message_loop/message_loop.h"
28 : #include "base/process/process_handle.h"
29 : #include "base/strings/string16.h"
30 : #include "base/strings/string_number_conversions.h"
31 : #include "base/synchronization/waitable_event.h"
32 : #include "base/threading/thread.h"
33 : #include "gtest/gtest.h"
34 : #include "syzygy/common/rpc/helpers.h"
35 : #include "syzygy/kasko/kasko_rpc.h"
36 : #include "syzygy/kasko/service.h"
37 : #include "syzygy/kasko/testing/mock_service.h"
38 :
39 m : namespace kasko {
40 :
41 m : namespace {
42 :
43 m : const base::char16* const kValidRpcProtocol = L"ncalrpc";
44 m : const base::char16* const kTestRpcEndpointPrefix = L"syzygy-kasko-test-svc";
45 :
46 m : base::string16 GetTestEndpoint() {
47 m : return kTestRpcEndpointPrefix + base::UintToString16(::GetCurrentProcessId());
48 m : }
49 :
50 m : class BlockingService : public Service {
51 m : public:
52 m : BlockingService(base::WaitableEvent* release_call,
53 m : base::WaitableEvent* blocking);
54 m : virtual ~BlockingService();
55 :
56 : // Service implementation
57 m : virtual void SendDiagnosticReport(
58 m : base::ProcessId client_process_id,
59 m : base::PlatformThreadId thread_id,
60 m : const MinidumpRequest& request) override;
61 :
62 m : private:
63 m : base::WaitableEvent* release_call_;
64 m : base::WaitableEvent* blocking_;
65 m : DISALLOW_COPY_AND_ASSIGN(BlockingService);
66 m : };
67 :
68 m : BlockingService::BlockingService(base::WaitableEvent* release_call,
69 m : base::WaitableEvent* blocking)
70 m : : release_call_(release_call), blocking_(blocking) {}
71 :
72 m : BlockingService::~BlockingService() {}
73 :
74 m : void BlockingService::SendDiagnosticReport(
75 m : base::ProcessId client_process_id,
76 m : base::PlatformThreadId thread_id,
77 m : const MinidumpRequest& request) {
78 m : blocking_->Signal();
79 m : release_call_->Wait();
80 m : }
81 :
82 m : void InvokeAndCheckRpcStatus(const base::Callback<RPC_STATUS(void)>& callback) {
83 m : ASSERT_EQ(RPC_S_OK, callback.Run());
84 m : }
85 :
86 m : base::Closure WrapRpcStatusCallback(
87 m : const base::Callback<RPC_STATUS(void)>& callback) {
88 m : return base::Bind(InvokeAndCheckRpcStatus, callback);
89 m : }
90 :
91 m : void DoInvokeService(const base::string16& protocol,
92 m : const base::string16& endpoint,
93 m : bool* complete,
94 m : long exception_info_address,
95 m : long thread_id,
96 m : DumpType dump_type,
97 m : size_t memory_ranges_length,
98 m : const MemoryRange* memory_ranges,
99 m : size_t crash_keys_length,
100 m : const CrashKey* crash_keys,
101 m : size_t custom_streams_length,
102 m : const CustomStream* custom_streams) {
103 m : common::rpc::ScopedRpcBinding rpc_binding;
104 m : ASSERT_TRUE(rpc_binding.Open(protocol, endpoint));
105 :
106 m : ::MinidumpRequest rpc_request = {exception_info_address,
107 m : thread_id,
108 m : dump_type,
109 m : memory_ranges_length,
110 m : memory_ranges,
111 m : crash_keys_length,
112 m : crash_keys,
113 m : custom_streams_length,
114 m : custom_streams};
115 :
116 m : common::rpc::RpcStatus status = common::rpc::InvokeRpc(
117 m : KaskoClient_SendDiagnosticReport, rpc_binding.Get(), rpc_request);
118 m : ASSERT_FALSE(status.exception_occurred);
119 m : ASSERT_TRUE(status.succeeded());
120 m : *complete = true;
121 m : }
122 :
123 m : } // namespace
124 :
125 m : TEST(KaskoServiceBridgeTest, ConstructDestruct) {
126 m : std::vector<testing::MockService::CallRecord> call_log;
127 m : {
128 m : ServiceBridge instance(
129 m : L"aaa", L"bbb",
130 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
131 m : }
132 m : {
133 m : ServiceBridge instance(
134 m : L"aaa", L"bbb",
135 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
136 m : }
137 m : }
138 :
139 m : TEST(KaskoServiceBridgeTest, StopNonRunningInstance) {
140 m : std::vector<testing::MockService::CallRecord> call_log;
141 m : ServiceBridge instance(
142 m : L"aaa", L"bbb",
143 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
144 m : instance.Stop();
145 m : }
146 :
147 m : TEST(KaskoServiceBridgeTest, FailToRunWithBadProtocol) {
148 m : std::vector<testing::MockService::CallRecord> call_log;
149 m : {
150 m : ServiceBridge instance(
151 m : L"aaa", GetTestEndpoint(),
152 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
153 m : ASSERT_FALSE(instance.Run());
154 : // Stop should not crash in this case.
155 m : instance.Stop();
156 m : }
157 m : }
158 :
159 m : TEST(KaskoServiceBridgeTest, RunSuccessfully) {
160 m : std::vector<testing::MockService::CallRecord> call_log;
161 :
162 m : {
163 m : ServiceBridge instance(
164 m : kValidRpcProtocol, GetTestEndpoint(),
165 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
166 m : ASSERT_TRUE(instance.Run());
167 m : instance.Stop();
168 :
169 : // Second run, same instance.
170 m : ASSERT_TRUE(instance.Run());
171 m : instance.Stop();
172 m : }
173 m : {
174 : // Second instance.
175 m : ServiceBridge instance(
176 m : kValidRpcProtocol, GetTestEndpoint(),
177 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
178 m : ASSERT_TRUE(instance.Run());
179 m : instance.Stop();
180 m : }
181 m : }
182 :
183 m : TEST(KaskoServiceBridgeTest, InvokeService) {
184 m : std::vector<testing::MockService::CallRecord> call_log;
185 :
186 m : base::string16 protocol = kValidRpcProtocol;
187 m : base::string16 endpoint = GetTestEndpoint();
188 m : ServiceBridge instance(
189 m : protocol, endpoint,
190 m : std::unique_ptr<Service>(new testing::MockService(&call_log)));
191 m : ASSERT_TRUE(instance.Run());
192 :
193 m : base::ScopedClosureRunner stop_service_bridge(
194 m : base::Bind(&ServiceBridge::Stop, base::Unretained(&instance)));
195 :
196 :
197 m : std::string stream_data = "hello world";
198 m : uint32_t kStreamType = 987;
199 m : CustomStream custom_streams[] = {
200 m : {kStreamType, stream_data.length(),
201 m : reinterpret_cast<const signed char*>(stream_data.data())}};
202 m : bool complete = false;
203 m : CrashKey crash_keys[] = {{reinterpret_cast<const wchar_t*>(L"foo"),
204 m : reinterpret_cast<const wchar_t*>(L"bar")},
205 m : {reinterpret_cast<const wchar_t*>(L"hello"),
206 m : reinterpret_cast<const wchar_t*>(L"world")}};
207 :
208 m : MemoryRange memory_ranges[] = {{0xdeadbeef, 123}};
209 :
210 m : DoInvokeService(protocol, endpoint, &complete, 0, 0, SMALL_DUMP,
211 m : arraysize(memory_ranges), memory_ranges,
212 m : arraysize(crash_keys), crash_keys, arraysize(custom_streams),
213 m : custom_streams);
214 m : ASSERT_TRUE(complete);
215 m : complete = false;
216 m : DoInvokeService(protocol, endpoint, &complete, 1122, 3, LARGER_DUMP, 0,
217 m : nullptr, 0, nullptr, 0, nullptr);
218 m : ASSERT_TRUE(complete);
219 :
220 m : ASSERT_EQ(2u, call_log.size());
221 :
222 : // First request
223 m : ASSERT_EQ(::GetCurrentProcessId(), call_log[0].client_process_id);
224 m : ASSERT_EQ(0, call_log[0].exception_info_address);
225 m : ASSERT_EQ(0, call_log[0].thread_id);
226 :
227 m : ASSERT_EQ(1u, call_log[0].user_selected_memory_ranges.size());
228 m : ASSERT_EQ(memory_ranges[0].base_address,
229 m : call_log[0].user_selected_memory_ranges[0].start());
230 m : ASSERT_EQ(memory_ranges[0].length,
231 m : call_log[0].user_selected_memory_ranges[0].size());
232 :
233 m : ASSERT_EQ(1u, call_log[0].custom_streams.size());
234 m : auto custom_streams_entry = call_log[0].custom_streams.find(kStreamType);
235 m : ASSERT_NE(call_log[0].custom_streams.end(), custom_streams_entry);
236 m : ASSERT_EQ(stream_data, custom_streams_entry->second);
237 :
238 m : ASSERT_EQ(2u, call_log[0].crash_keys.size());
239 m : auto crash_keys_entry = call_log[0].crash_keys.find(L"foo");
240 m : ASSERT_NE(call_log[0].crash_keys.end(), crash_keys_entry);
241 m : ASSERT_EQ(L"bar", crash_keys_entry->second);
242 m : crash_keys_entry = call_log[0].crash_keys.find(L"hello");
243 m : ASSERT_NE(call_log[0].crash_keys.end(), crash_keys_entry);
244 m : ASSERT_EQ(L"world", crash_keys_entry->second);
245 :
246 : // Second request
247 m : ASSERT_EQ(::GetCurrentProcessId(), call_log[1].client_process_id);
248 m : ASSERT_EQ(1122, call_log[1].exception_info_address);
249 m : ASSERT_EQ(3, call_log[1].thread_id);
250 m : ASSERT_EQ(0u, call_log[1].custom_streams.size());
251 m : ASSERT_EQ(0u, call_log[1].crash_keys.size());
252 m : }
253 :
254 :
// Verifies that ServiceBridge::Stop() does not return while an RPC call is
// still being serviced. A client thread issues a call that parks inside
// BlockingService. Stop() is then started on a separate thread, and the parked
// call is released only after that.
TEST(KaskoServiceBridgeTest, StopBlocksUntilCallsComplete) {
  // (manual_reset = false, initially_signaled = false): auto-reset events that
  // start unsignalled.
  base::WaitableEvent release_call(false, false);
  base::WaitableEvent blocking(false, false);

  base::string16 protocol = kValidRpcProtocol;
  base::string16 endpoint = GetTestEndpoint();
  ServiceBridge instance(
      protocol, endpoint,
      std::unique_ptr<Service>(new BlockingService(&release_call, &blocking)));
  ASSERT_TRUE(instance.Run());

  // Stop the bridge on every exit path. Because it is declared before
  // signal_release_call, it is destroyed (and Stop() runs) after release_call
  // has been signalled.
  base::ScopedClosureRunner stop_service_bridge(
      base::Bind(&ServiceBridge::Stop, base::Unretained(&instance)));
  // In case an assertion fails, make sure that we will not block.
  // NOTE(review): client_thread and stop_thread are declared later and are
  // therefore destroyed (joined) before this runner signals release_call.
  // Suppose ASSERT_TRUE(stop_thread.Start()) fails while the call is parked.
  // Joining client_thread could then wait on a call that is only released by
  // this runner. Confirm that case cannot hang.
  base::ScopedClosureRunner signal_release_call(base::Bind(
      &base::WaitableEvent::Signal, base::Unretained(&release_call)));

  bool complete = false;
  CrashKey crash_keys[] = {{reinterpret_cast<const wchar_t*>(L"foo"),
                            reinterpret_cast<const wchar_t*>(L"bar")},
                           {reinterpret_cast<const wchar_t*>(L"hello"),
                            reinterpret_cast<const wchar_t*>(L"world")}};

  // Issue the RPC from a separate thread. It is expected to park inside
  // BlockingService::SendDiagnosticReport until release_call is signalled.
  base::Thread client_thread("client thread");
  ASSERT_TRUE(client_thread.Start());
  client_thread.message_loop()->PostTask(
      FROM_HERE, base::Bind(&DoInvokeService, protocol, endpoint,
                            base::Unretained(&complete), 0, 0, SMALL_DUMP, 0,
                            nullptr, arraysize(crash_keys),
                            base::Unretained(crash_keys), 0, nullptr));
  // In case the DoInvokeService fails, let's make sure we unblock ourselves.
  // Tasks on client_thread run in order, so this one only runs after
  // DoInvokeService has returned.
  client_thread.message_loop()->PostTask(
      FROM_HERE,
      base::Bind(&base::WaitableEvent::Signal, base::Unretained(&blocking)));
  blocking.Wait();

  // There are two ways to get here. Normally the call is parked in
  // BlockingService::SendDiagnosticReport (complete == false). Alternatively,
  // DoInvokeService returned early on a failed assertion and the fallback task
  // above signalled |blocking|. In that case |complete| is also false, because
  // it is set only on success, and gtest records the failure.
  // complete == true would mean the call returned without being released.
  ASSERT_FALSE(complete);

  // Reduce the chance of false positives by giving the service call a chance to
  // complete. (It shouldn't.)
  ::Sleep(100);

  // Run Stop() on its own thread. It is expected to block on the parked call.
  base::Thread stop_thread("stop thread");
  ASSERT_TRUE(stop_thread.Start());
  stop_thread.message_loop()->PostTask(
      FROM_HERE, base::Bind(&ServiceBridge::Stop, base::Unretained(&instance)));
  // Sanity check only: nothing here guarantees that Stop() has already begun
  // executing on stop_thread.
  ASSERT_FALSE(complete);

  // Stop is waiting for the pending call to complete. Let's unblock it now.
  release_call.Signal();

  // This will not return until the ServiceBridge::Stop has completed.
  stop_thread.Stop();
  // NOTE(review): |complete| is written on client_thread without
  // synchronization, and only after the RPC reply has reached the client.
  // Stop() returning may not strictly order that write before this read.
  // Confirm this check cannot flake.
  ASSERT_TRUE(complete);
}
312 :
313 m : } // namespace kasko
|