1 : // Copyright 2012 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/dia_util.h"
16 :
17 : #include <diacreate.h>
18 :
19 : #include "base/logging.h"
20 : #include "base/win/scoped_bstr.h"
21 : #include "base/win/scoped_comptr.h"
22 : #include "syzygy/common/com_utils.h"
23 :
24 : namespace pe {
25 :
26 : using base::win::ScopedBstr;
27 : using base::win::ScopedComPtr;
28 :
29 : const wchar_t kDiaDllName[] = L"msdia120.dll";
30 :
31 : const wchar_t kFixupDiaDebugStreamName[] = L"FIXUP";
32 : const wchar_t kOmapToDiaDebugStreamName[] = L"OMAPTO";
33 : const wchar_t kOmapFromDiaDebugStreamName[] = L"OMAPFROM";
34 :
35 E : bool CreateDiaSource(IDiaDataSource** created_source) {
36 E : DCHECK(created_source != NULL);
37 :
38 E : *created_source = NULL;
39 :
40 E : ScopedComPtr<IDiaDataSource> dia_source;
41 E : HRESULT hr1 = dia_source.CreateInstance(CLSID_DiaSource);
42 E : if (SUCCEEDED(hr1)) {
43 E : *created_source = dia_source.Detach();
44 E : return true;
45 : }
46 :
47 : HRESULT hr2 = NoRegCoCreate(kDiaDllName,
48 : CLSID_DiaSource,
49 : IID_IDiaDataSource,
50 E : reinterpret_cast<void**>(&dia_source));
51 E : if (SUCCEEDED(hr2)) {
52 E : *created_source = dia_source.Detach();
53 E : return true;
54 : }
55 :
56 i : LOG(ERROR) << "Failed to create DiaDataSource.";
57 i : LOG(ERROR) << " CreateInstance failed with: " << common::LogHr(hr1);
58 i : LOG(ERROR) << " NoRegCoCreate failed with: " << common::LogHr(hr2);
59 :
60 i : return false;
61 E : }
62 :
63 : bool CreateDiaSession(const base::FilePath& file,
64 : IDiaDataSource* dia_source,
65 E : IDiaSession** dia_session) {
66 E : DCHECK(dia_source != NULL);
67 E : DCHECK(dia_session != NULL);
68 :
69 E : *dia_session = NULL;
70 :
71 E : HRESULT hr = E_FAIL;
72 :
73 E : if (file.Extension() == L".pdb") {
74 E : hr = dia_source->loadDataFromPdb(file.value().c_str());
75 E : } else {
76 E : hr = dia_source->loadDataForExe(file.value().c_str(), NULL, NULL);
77 : }
78 :
79 E : if (FAILED(hr)) {
80 i : LOG(ERROR) << "Failed to load DIA data for \"" << file.value() << "\": "
81 : << common::LogHr(hr) << ".";
82 i : return false;
83 : }
84 :
85 E : ScopedComPtr<IDiaSession> session;
86 E : hr = dia_source->openSession(session.Receive());
87 E : if (FAILED(hr)) {
88 i : LOG(ERROR) << "Failed to open DIA session for \"" << file.value() << "\" : "
89 : << common::LogHr(hr) << ".";
90 i : return false;
91 : }
92 :
93 E : *dia_session = session.Detach();
94 :
95 E : return true;
96 E : }
97 :
98 : SearchResult FindDiaTable(const IID& iid,
99 : IDiaSession* dia_session,
100 E : void** out_table) {
101 E : DCHECK(dia_session != NULL);
102 E : DCHECK(out_table != NULL);
103 :
104 E : *out_table = NULL;
105 :
106 : // Get the table enumerator.
107 E : base::win::ScopedComPtr<IDiaEnumTables> enum_tables;
108 E : HRESULT hr = dia_session->getEnumTables(enum_tables.Receive());
109 E : if (FAILED(hr)) {
110 i : LOG(ERROR) << "Failed to get DIA table enumerator: "
111 : << common::LogHr(hr) << ".";
112 i : return kSearchErrored;
113 : }
114 :
115 : // Iterate through the tables.
116 E : while (true) {
117 E : base::win::ScopedComPtr<IDiaTable> table;
118 E : ULONG fetched = 0;
119 E : hr = enum_tables->Next(1, table.Receive(), &fetched);
120 E : if (FAILED(hr)) {
121 i : LOG(ERROR) << "Failed to get DIA table: "
122 : << common::LogHr(hr) << ".";
123 i : return kSearchErrored;
124 : }
125 E : if (fetched == 0)
126 i : break;
127 :
128 E : hr = table.QueryInterface(iid, out_table);
129 E : if (SUCCEEDED(hr))
130 E : return kSearchSucceeded;
131 E : }
132 :
133 : // The search completed, even though we didn't find what we were looking for.
134 i : return kSearchFailed;
135 E : }
136 :
137 : SearchResult FindDiaDebugStream(const wchar_t* name,
138 : IDiaSession* dia_session,
139 E : IDiaEnumDebugStreamData** dia_debug_stream) {
140 E : DCHECK(name != NULL);
141 E : DCHECK(dia_session != NULL);
142 E : DCHECK(dia_debug_stream != NULL);
143 :
144 E : *dia_debug_stream = NULL;
145 :
146 E : HRESULT hr = E_FAIL;
147 E : ScopedComPtr<IDiaEnumDebugStreams> debug_streams;
148 E : if (FAILED(hr = dia_session->getEnumDebugStreams(debug_streams.Receive()))) {
149 i : LOG(ERROR) << "Unable to get debug streams: " << common::LogHr(hr) << ".";
150 i : return kSearchErrored;
151 : }
152 :
153 : // Iterate through the debug streams.
154 E : while (true) {
155 E : ScopedComPtr<IDiaEnumDebugStreamData> debug_stream;
156 E : ULONG count = 0;
157 E : HRESULT hr = debug_streams->Next(1, debug_stream.Receive(), &count);
158 E : if (FAILED(hr) || (hr != S_FALSE && count != 1)) {
159 i : LOG(ERROR) << "Unable to load debug stream: "
160 : << common::LogHr(hr) << ".";
161 i : return kSearchErrored;
162 E : } else if (hr == S_FALSE) {
163 : // No more records.
164 E : break;
165 : }
166 :
167 E : ScopedBstr stream_name;
168 E : if (FAILED(hr = debug_stream->get_name(stream_name.Receive()))) {
169 i : LOG(ERROR) << "Unable to get debug stream name: "
170 : << common::LogHr(hr) << ".";
171 i : return kSearchErrored;
172 : }
173 :
174 : // Found the stream?
175 E : if (wcscmp(common::ToString(stream_name), name) == 0) {
176 E : *dia_debug_stream = debug_stream.Detach();
177 E : return kSearchSucceeded;
178 : }
179 E : }
180 :
181 E : return kSearchFailed;
182 E : }
183 :
184 E : bool GetSymTag(IDiaSymbol* symbol, enum SymTagEnum* sym_tag) {
185 E : DCHECK(symbol != NULL);
186 E : DCHECK(sym_tag != NULL);
187 E : DWORD tmp_tag = SymTagNull;
188 E : *sym_tag = SymTagNull;
189 E : HRESULT hr = symbol->get_symTag(&tmp_tag);
190 E : if (hr != S_OK) {
191 i : LOG(ERROR) << "Error getting sym tag: " << common::LogHr(hr) << ".";
192 i : return false;
193 : }
194 E : *sym_tag = static_cast<enum SymTagEnum>(tmp_tag);
195 E : return true;
196 E : }
197 :
198 E : bool IsSymTag(IDiaSymbol* symbol, enum SymTagEnum expected_sym_tag) {
199 E : DCHECK(symbol != NULL);
200 E : DCHECK(expected_sym_tag != SymTagNull);
201 :
202 E : enum SymTagEnum sym_tag = SymTagNull;
203 E : if (!GetSymTag(symbol, &sym_tag))
204 i : return false;
205 :
206 E : return sym_tag == expected_sym_tag;
207 E : }
208 :
209 : ChildVisitor::ChildVisitor(IDiaSymbol* parent, enum SymTagEnum type)
210 E : : parent_(parent), type_(type), child_callback_(NULL) {
211 E : DCHECK(parent != NULL);
212 E : }
213 :
214 E : bool ChildVisitor::VisitChildren(const VisitSymbolCallback& child_callback) {
215 E : DCHECK(child_callback_ == NULL);
216 :
217 E : child_callback_ = &child_callback;
218 E : bool ret = VisitChildrenImpl();
219 E : child_callback_ = NULL;
220 :
221 E : return ret;
222 E : }
223 :
224 E : bool ChildVisitor::VisitChildrenImpl() {
225 E : DCHECK(child_callback_ != NULL);
226 :
227 : // Retrieve an enumerator for all children in this PDB.
228 E : base::win::ScopedComPtr<IDiaEnumSymbols> children;
229 : HRESULT hr = parent_->findChildren(type_,
230 : NULL,
231 : nsNone,
232 E : children.Receive());
233 E : if (FAILED(hr)) {
234 i : LOG(ERROR) << "Unable to get children: " << common::LogHr(hr);
235 i : return false;
236 : }
237 :
238 E : return EnumerateChildren(children);
239 E : }
240 :
241 E : bool ChildVisitor::EnumerateChildren(IDiaEnumSymbols* children) {
242 E : DCHECK(children!= NULL);
243 :
244 E : while (true) {
245 E : base::win::ScopedComPtr<IDiaSymbol> child;
246 E : ULONG fetched = 0;
247 E : HRESULT hr = children->Next(1, child.Receive(), &fetched);
248 E : if (FAILED(hr)) {
249 i : DCHECK_EQ(0U, fetched);
250 i : DCHECK(child == NULL);
251 i : LOG(ERROR) << "Unable to iterate children: " << common::LogHr(hr);
252 i : return false;
253 : }
254 E : if (hr == S_FALSE)
255 E : break;
256 :
257 E : DCHECK_EQ(1U, fetched);
258 E : DCHECK(child != NULL);
259 :
260 E : if (!VisitChild(child))
261 E : return false;
262 E : }
263 :
264 E : return true;
265 E : }
266 :
267 E : bool ChildVisitor::VisitChild(IDiaSymbol* child) {
268 E : DCHECK(child_callback_ != NULL);
269 :
270 E : return child_callback_->Run(child);
271 E : }
272 :
273 E : CompilandVisitor::CompilandVisitor(IDiaSession* session) : session_(session) {
274 E : DCHECK(session != NULL);
275 E : }
276 :
277 : bool CompilandVisitor::VisitAllCompilands(
278 E : const VisitCompilandCallback& compiland_callback) {
279 E : base::win::ScopedComPtr<IDiaSymbol> global;
280 E : HRESULT hr = session_->get_globalScope(global.Receive());
281 E : if (FAILED(hr)) {
282 i : LOG(ERROR) << "Unable to get global scope: " << common::LogHr(hr);
283 i : return false;
284 : }
285 :
286 E : ChildVisitor visitor(global, SymTagCompiland);
287 :
288 E : return visitor.VisitChildren(compiland_callback);
289 E : }
290 :
291 : LineVisitor::LineVisitor(IDiaSession* session, IDiaSymbol* compiland)
292 E : : session_(session), compiland_(compiland), line_callback_(NULL) {
293 E : DCHECK(session != NULL);
294 E : }
295 :
296 E : bool LineVisitor::VisitLines(const VisitLineCallback& line_callback) {
297 E : DCHECK(line_callback_ == NULL);
298 :
299 E : line_callback_ = &line_callback;
300 E : bool ret = VisitLinesImpl();
301 E : line_callback_ = NULL;
302 :
303 E : return ret;
304 E : }
305 :
306 : bool LineVisitor::EnumerateCompilandSource(IDiaSymbol* compiland,
307 E : IDiaSourceFile* source_file) {
308 E : DCHECK(compiland != NULL);
309 E : DCHECK(source_file != NULL);
310 :
311 E : base::win::ScopedComPtr<IDiaEnumLineNumbers> line_numbers;
312 : HRESULT hr = session_->findLines(compiland,
313 : source_file,
314 E : line_numbers.Receive());
315 E : if (FAILED(hr)) {
316 : // This seems to happen for the occasional header file.
317 i : return true;
318 : }
319 :
320 E : while (true) {
321 E : base::win::ScopedComPtr<IDiaLineNumber> line_number;
322 E : ULONG fetched = 0;
323 E : hr = line_numbers->Next(1, line_number.Receive(), &fetched);
324 E : if (FAILED(hr)) {
325 i : DCHECK_EQ(0U, fetched);
326 i : DCHECK(line_number == NULL);
327 i : LOG(ERROR) << "Unable to iterate line numbers: " << common::LogHr(hr);
328 i : return false;
329 : }
330 E : if (hr == S_FALSE)
331 E : break;
332 :
333 E : DCHECK_EQ(1U, fetched);
334 E : DCHECK(line_number != NULL);
335 :
336 E : if (!VisitSourceLine(line_number))
337 i : return false;
338 E : }
339 :
340 E : return true;
341 E : }
342 :
343 : bool LineVisitor::EnumerateCompilandSources(IDiaSymbol* compiland,
344 E : IDiaEnumSourceFiles* source_files) {
345 E : DCHECK(compiland != NULL);
346 E : DCHECK(source_files != NULL);
347 :
348 E : while (true) {
349 E : base::win::ScopedComPtr<IDiaSourceFile> source_file;
350 E : ULONG fetched = 0;
351 E : HRESULT hr = source_files->Next(1, source_file.Receive(), &fetched);
352 E : if (FAILED(hr)) {
353 i : DCHECK_EQ(0U, fetched);
354 i : DCHECK(source_file == NULL);
355 i : LOG(ERROR) << "Unable to iterate source files: " << common::LogHr(hr);
356 i : return false;
357 : }
358 E : if (hr == S_FALSE)
359 E : break;
360 :
361 E : DCHECK_EQ(1U, fetched);
362 E : DCHECK(compiland != NULL);
363 :
364 E : if (!EnumerateCompilandSource(compiland, source_file))
365 i : return false;
366 E : }
367 :
368 E : return true;
369 E : }
370 :
371 E : bool LineVisitor::VisitLinesImpl() {
372 E : DCHECK(session_ != NULL);
373 E : DCHECK(compiland_ != NULL);
374 E : DCHECK(line_callback_ != NULL);
375 :
376 : // Enumerate all source files referenced by this compiland.
377 E : base::win::ScopedComPtr<IDiaEnumSourceFiles> source_files;
378 : HRESULT hr = session_->findFile(compiland_,
379 : NULL,
380 : nsNone,
381 E : source_files.Receive());
382 E : if (FAILED(hr)) {
383 i : LOG(ERROR) << "Unable to get source files: " << common::LogHr(hr);
384 i : return false;
385 : }
386 :
387 E : return EnumerateCompilandSources(compiland_, source_files);
388 E : }
389 :
390 E : bool LineVisitor::VisitSourceLine(IDiaLineNumber* line_number) {
391 E : DCHECK(line_callback_ != NULL);
392 :
393 E : return line_callback_->Run(line_number);
394 E : }
395 :
396 : } // namespace pe
|